From dcc882beb6283f0dcd62745b42c3106eead24f36 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 23 Sep 2025 15:29:49 -0400 Subject: [PATCH 01/19] Parquet Modular Encryption support for native readers using KMS on Spark side accessed via JNI. --- .../comet/parquet/CometFileKeyUnwrapper.java | 88 ++++++++++ .../java/org/apache/comet/parquet/Native.java | 5 +- .../comet/parquet/NativeBatchReader.java | 15 +- .../comet/objectstore/NativeConfig.scala | 2 +- native/Cargo.lock | 4 + native/core/Cargo.toml | 2 +- native/core/src/execution/jni_api.rs | 13 ++ native/core/src/execution/planner.rs | 2 + native/core/src/parquet/mod.rs | 25 ++- native/core/src/parquet/parquet_exec.rs | 161 +++++++++++++++++- native/proto/src/proto/operator.proto | 1 + .../org/apache/comet/CometExecIterator.scala | 30 +++- .../main/scala/org/apache/comet/Native.scala | 5 +- .../apache/comet/rules/CometScanRule.scala | 21 +-- .../apache/comet/serde/QueryPlanSerde.scala | 6 + .../shuffle/CometNativeShuffleWriter.scala | 3 +- .../apache/spark/sql/comet/operators.scala | 54 +++++- .../sql/comet/ParquetEncryptionITCase.scala | 35 ++-- 18 files changed, 423 insertions(+), 49 deletions(-) create mode 100644 common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java new file mode 100644 index 0000000000..ea4037dbf0 --- /dev/null +++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.parquet; + +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.parquet.crypto.DecryptionKeyRetriever; +import org.apache.parquet.crypto.DecryptionPropertiesFactory; +import org.apache.parquet.crypto.FileDecryptionProperties; +import org.apache.parquet.crypto.ParquetCryptoRuntimeException; + +/** + * Helper class to access DecryptionKeyRetriever.getKey from native code via JNI. This class handles + * the complexity of getting the proper Hadoop Configuration from the current Spark context and + * creating properly configured DecryptionKeyRetriever instances using DecryptionPropertiesFactory. + */ +public class CometFileKeyUnwrapper { + + // Each file path gets a unique DecryptionKeyRetriever + private final ConcurrentHashMap retrieverCache = + new ConcurrentHashMap<>(); + + // Each hadoopConf yields a unique DecryptionPropertiesFactory. While it's unlikely that + // this plan contains more than one hadoopConf, we don't want to assume that. 
So we'll + // provide the ability to cache more than one Factory with a map. + private final ConcurrentHashMap factoryCache = + new ConcurrentHashMap<>(); + + /** + * Creates and stores a DecryptionKeyRetriever instance for the given file path. + * + * @param filePath The path to the Parquet file + * @param hadoopConf The Hadoop Configuration to use for this file path + */ + public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) { + // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call + // getFileDecryptionProperties + DecryptionPropertiesFactory factory = factoryCache.get(hadoopConf); + if (factory == null) { + factory = DecryptionPropertiesFactory.loadFactory(hadoopConf); + factoryCache.put(hadoopConf, factory); + } + Path path = new Path(filePath); + FileDecryptionProperties decryptionProperties = + factory.getFileDecryptionProperties(hadoopConf, path); + + DecryptionKeyRetriever keyRetriever = decryptionProperties.getKeyRetriever(); + retrieverCache.put(filePath, keyRetriever); + } + + /** + * Gets the decryption key for the given key metadata using the cached DecryptionKeyRetriever for + * the specified file path. + * + * @param filePath The path to the Parquet file + * @param keyMetadata The key metadata bytes from the Parquet file + * @return The decrypted key bytes + * @throws ParquetCryptoRuntimeException if key unwrapping fails + */ + public byte[] getKey(final String filePath, final byte[] keyMetadata) + throws ParquetCryptoRuntimeException { + DecryptionKeyRetriever keyRetriever = retrieverCache.get(filePath); + if (keyRetriever == null) { + throw new ParquetCryptoRuntimeException( + "Failed to find DecryptionKeyRetriever for path: " + filePath); + } + return keyRetriever.getKey(keyMetadata); + } +} diff --git a/common/src/main/java/org/apache/comet/parquet/Native.java b/common/src/main/java/org/apache/comet/parquet/Native.java index cceb1085c3..dbddc3b743 100644 --- a/common/src/main/java/org/apache/comet/parquet/Native.java +++ b/common/src/main/java/org/apache/comet/parquet/Native.java @@ -267,9 +267,11 @@ public static native long initRecordBatchReader( String sessionTimezone, int batchSize, boolean caseSensitive, - Map objectStoreOptions); + Map objectStoreOptions, + CometFileKeyUnwrapper keyUnwrapper); // arrow native version of read batch + /** * Read the next batch of data into memory on native side * @@ -280,6 +282,7 @@ public static native long initRecordBatchReader( // arrow native equivalent of currentBatch. 'columnNum' is number of the column in the record // batch + /** * Load the column corresponding to columnNum in the currently loaded record batch into JVM * diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 67c2775400..1897889359 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -80,7 +80,7 @@ import org.apache.comet.vector.CometVector; import org.apache.comet.vector.NativeUtil; -import static scala.jdk.javaapi.CollectionConverters.*; +import static scala.jdk.javaapi.CollectionConverters.asJava; /** * A vectorized Parquet reader that reads a Parquet file in a batched fashion. 
@@ -410,6 +410,16 @@ public void init() throws Throwable { } } + boolean encryptionEnabled = + !conf.get("parquet.crypto.factory.class").isEmpty() + || !conf.get("parquet.encryption.kms.client.class", "").isEmpty(); + + // Create keyUnwrapper if encryption is enabled + CometFileKeyUnwrapper keyUnwrapper = encryptionEnabled ? new CometFileKeyUnwrapper() : null; + if (encryptionEnabled) { + keyUnwrapper.storeDecryptionKeyRetriever(file.filePath().toString(), conf); + } + int batchSize = conf.getInt( CometConf.COMET_BATCH_SIZE().key(), @@ -426,7 +436,8 @@ public void init() throws Throwable { timeZoneId, batchSize, caseSensitive, - objectStoreOptions); + objectStoreOptions, + keyUnwrapper); } isInitialized = true; } diff --git a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala index b930aea17a..885b4686e7 100644 --- a/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala +++ b/common/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala @@ -58,7 +58,7 @@ object NativeConfig { def extractObjectStoreOptions(hadoopConf: Configuration, uri: URI): Map[String, String] = { val scheme = uri.getScheme.toLowerCase(Locale.ROOT) - import scala.collection.JavaConverters._ + import scala.jdk.CollectionConverters._ val options = scala.collection.mutable.Map[String, String]() // The schemes will use libhdfs diff --git a/native/Cargo.lock b/native/Cargo.lock index 24551efb40..eb1c9e8c53 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1433,6 +1433,7 @@ dependencies = [ "datafusion-session", "datafusion-sql", "futures", + "hex", "itertools 0.14.0", "log", "object_store", @@ -1620,6 +1621,7 @@ dependencies = [ "chrono", "half", "hashbrown 0.14.5", + "hex", "indexmap", "libc", "log", @@ -1747,6 +1749,7 @@ dependencies = [ "datafusion-pruning", "datafusion-session", "futures", + "hex", "itertools 0.14.0", "log", "object_store", @@ -1777,6 +1780,7 @@ dependencies = [ "log", "object_store", "parking_lot", + "parquet", "rand", "tempfile", "url", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 323b5d0338..295f5831b2 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -59,7 +59,7 @@ bytes = { workspace = true } tempfile = "3.8.0" itertools = "0.14.0" paste = "1.0.14" -datafusion = { workspace = true } +datafusion = { workspace = true, features = ["parquet_encryption"] } datafusion-spark = { workspace = true } once_cell = "1.18.0" regex = { workspace = true } diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 3446f42b24..a120e8a29a 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -75,6 +75,7 @@ use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; +use crate::parquet::parquet_exec::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use datafusion_comet_proto::spark_operator::operator::OpStruct; use log::info; use once_cell::sync::Lazy; @@ -167,6 +168,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( debug_native: jboolean, explain_native: jboolean, tracing_enabled: jboolean, + key_unwrapper_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |mut env| { with_trace("createPlan", tracing_enabled != JNI_FALSE, || { @@ -239,6 +241,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( None }; + // Handle key unwrapper for encrypted files + if 
!key_unwrapper_obj.is_null() { + let encryption_factory = CometEncryptionFactory { + key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + }; + session.runtime_env().register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + } + let exec_context = Box::new(ExecutionContext { id, task_attempt_id, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5aa6ece3bc..294ddcdc9c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1360,6 +1360,8 @@ impl PhysicalPlanner { default_values, scan.session_timezone.as_str(), scan.case_sensitive, + self.session_ctx(), + scan.encryption_enabled, )?; Ok(( vec![], diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index a6efe4ed53..876ab488bb 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -52,8 +52,11 @@ use crate::execution::operators::ExecutionError; use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; use crate::execution::utils::SparkArrowConvert; +use crate::jvm_bridge::{jni_new_global_ref, JVMClasses}; use crate::parquet::data_type::AsBytes; -use crate::parquet::parquet_exec::init_datasource_exec; +use crate::parquet::parquet_exec::{ + init_datasource_exec, CometEncryptionFactory, ENCRYPTION_FACTORY_ID, +}; use crate::parquet::parquet_support::prepare_object_store_with_configs; use arrow::array::{Array, RecordBatch}; use arrow::buffer::{Buffer, MutableBuffer}; @@ -712,8 +715,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat batch_size: jint, case_sensitive: jboolean, object_store_options: jobject, + key_unwrapper_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |mut env| unsafe { + JVMClasses::init(&mut env); let session_config = SessionConfig::new().with_batch_size(batch_size as usize); let planner = PhysicalPlanner::new(Arc::new(SessionContext::new_with_config(session_config)), 0); @@ -766,6 +771,22 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat .unwrap() .into(); + // Handle key unwrapper for encrypted files + let encryption_enabled = if !key_unwrapper_obj.is_null() { + let encryption_factory = CometEncryptionFactory { + key_unwrapper: jni_new_global_ref!(env, key_unwrapper_obj)?, + }; + session_ctx + .runtime_env() + .register_parquet_encryption_factory( + ENCRYPTION_FACTORY_ID, + Arc::new(encryption_factory), + ); + true + } else { + false + }; + let scan = init_datasource_exec( required_schema, Some(data_schema), @@ -778,6 +799,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat None, session_timezone.as_str(), case_sensitive != JNI_FALSE, + session_ctx, + encryption_enabled, )?; let partition_index: usize = 0; diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs index 4b587b7ba6..56f575dc2a 100644 --- a/native/core/src/parquet/parquet_exec.rs +++ b/native/core/src/parquet/parquet_exec.rs @@ -16,21 +16,31 @@ // under the License. 
use crate::execution::operators::ExecutionError; +use crate::jvm_bridge::JVMClasses; use crate::parquet::parquet_support::SparkParquetOptions; use crate::parquet::schema_adapter::SparkSchemaAdapterFactory; use arrow::datatypes::{Field, SchemaRef}; -use datafusion::config::TableParquetOptions; +use async_trait::async_trait; +use datafusion::common::extensions_options; +use datafusion::config::{EncryptionFactoryOptions, TableParquetOptions}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{ FileGroup, FileScanConfigBuilder, FileSource, ParquetSource, }; use datafusion::datasource::source::DataSourceExec; +use datafusion::error::DataFusionError; use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::execution::parquet_encryption::EncryptionFactory; use datafusion::physical_expr::expressions::BinaryExpr; use datafusion::physical_expr::PhysicalExpr; +use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_comet_spark_expr::EvalMode; use itertools::Itertools; +use jni::objects::{GlobalRef, JMethodID}; +use object_store::path::Path; +use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; +use parquet::encryption::encrypt::FileEncryptionProperties; use std::collections::HashMap; use std::sync::Arc; @@ -66,9 +76,16 @@ pub(crate) fn init_datasource_exec( default_values: Option>, session_timezone: &str, case_sensitive: bool, + session_ctx: &Arc, + encryption_enabled: bool, ) -> Result, ExecutionError> { - let (table_parquet_options, spark_parquet_options) = - get_options(session_timezone, case_sensitive); + let (table_parquet_options, spark_parquet_options) = get_options( + session_timezone, + case_sensitive, + &object_store_url, + encryption_enabled, + ); + let mut parquet_source = ParquetSource::new(table_parquet_options); // Create a conjunctive form of the vector because ParquetExecBuilder takes @@ -87,6 +104,12 @@ pub(crate) fn init_datasource_exec( } } + parquet_source = parquet_source.with_encryption_factory( + session_ctx + .runtime_env() + .parquet_encryption_factory(ENCRYPTION_FACTORY_ID)?, + ); + let file_source = parquet_source.with_schema_adapter_factory(Arc::new( SparkSchemaAdapterFactory::new(spark_parquet_options, default_values), ))?; @@ -122,9 +145,131 @@ pub(crate) fn init_datasource_exec( Ok(Arc::new(DataSourceExec::new(Arc::new(file_scan_config)))) } +pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption"; + +// Options used to configure our example encryption factory +extensions_options! { + struct CometEncryptionConfig { + url_base: String, default = "file:///".into() + } +} +#[derive(Debug)] +pub struct CometEncryptionFactory { + pub(crate) key_unwrapper: GlobalRef, +} + +/// `EncryptionFactory` is a DataFusion trait for types that generate +/// file encryption and decryption properties. +#[async_trait] +impl EncryptionFactory for CometEncryptionFactory { + async fn get_file_encryption_properties( + &self, + _options: &EncryptionFactoryOptions, + _schema: &SchemaRef, + _file_path: &Path, + ) -> Result, DataFusionError> { + Err(DataFusionError::NotImplemented( + "Comet does not support Parquet encryption yet." + .parse() + .unwrap(), + )) + } + + /// Generate file decryption properties to use when reading a Parquet file. + /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` + /// that can determine the keys using the encryption metadata. 
+ async fn get_file_decryption_properties( + &self, + options: &EncryptionFactoryOptions, + file_path: &Path, + ) -> Result, DataFusionError> { + let config: CometEncryptionConfig = options.to_extension_options()?; + + let full_path: String = config.url_base + file_path.as_ref(); + let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let decryption_properties = + FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?; + Ok(Some(decryption_properties)) + } +} + +struct CometKeyRetriever { + file_path: String, + key_unwrapper: GlobalRef, + get_key_method_id: JMethodID, +} + +impl CometKeyRetriever { + fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { + // Get JNI environment + let mut env = JVMClasses::get_env()?; + + Ok(CometKeyRetriever { + file_path: file_path.to_string(), + key_unwrapper, + get_key_method_id: env + .get_method_id( + "org/apache/comet/parquet/CometFileKeyUnwrapper", + "getKey", + "(Ljava/lang/String;[B)[B", + ) + .unwrap(), + }) + } +} + +impl KeyRetriever for CometKeyRetriever { + /// Get a data encryption key using the metadata stored in the Parquet file. + fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result> { + use jni::{objects::JObject, signature::ReturnType}; + + // Get JNI environment + let mut env = JVMClasses::get_env() + .map_err(|e| datafusion::parquet::errors::ParquetError::General(e.to_string()))?; + + // Get the key unwrapper instance from GlobalRef + let unwrapper_instance = self.key_unwrapper.as_obj(); + + let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; + + // Convert file path to JString + let file_path_jstring = env.new_string(&self.file_path).unwrap(); + + // Convert key_metadata to JByteArray + let key_metadata_array = env.byte_array_from_slice(key_metadata).unwrap(); + + // Call instance method FileKeyUnwrapper.getKey(String, byte[]) -> byte[] + let result = unsafe { + env.call_method_unchecked( + instance, + self.get_key_method_id, + ReturnType::Array, + &[ + jni::objects::JValue::from(&file_path_jstring).as_jni(), + jni::objects::JValue::from(&key_metadata_array).as_jni(), + ], + ) + }; + + let result = result.unwrap(); + + // Extract the byte array from the result + let result_array = result.l().unwrap(); + + // Convert JObject to JByteArray and then to Vec + let byte_array: jni::objects::JByteArray = result_array.into(); + + let result_vec = env.convert_byte_array(&byte_array).unwrap(); + Ok(result_vec) + } +} + fn get_options( session_timezone: &str, case_sensitive: bool, + object_store_url: &ObjectStoreUrl, + encryption_enabled: bool, ) -> (TableParquetOptions, SparkParquetOptions) { let mut table_parquet_options = TableParquetOptions::new(); table_parquet_options.global.pushdown_filters = true; @@ -134,6 +279,16 @@ fn get_options( SparkParquetOptions::new(EvalMode::Legacy, session_timezone, false); spark_parquet_options.allow_cast_unsigned_ints = true; spark_parquet_options.case_sensitive = case_sensitive; + + if encryption_enabled { + table_parquet_options.crypto.configure_factory( + ENCRYPTION_FACTORY_ID, + &CometEncryptionConfig { + url_base: object_store_url.to_string(), + }, + ); + } + (table_parquet_options, spark_parquet_options) } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 57e012b369..a243ab6b03 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ 
-104,6 +104,7 @@ message NativeScan { // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in // the map. map object_store_options = 13; + bool encryption_enabled = 14; } message Projection { diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 67d044f8c5..69bdd39774 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -24,14 +24,18 @@ import java.lang.management.ManagementFactory import scala.util.matching.Regex +import org.apache.hadoop.conf.Configuration import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.comet.CometMetricNode import org.apache.spark.sql.vectorized._ +import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_METRICS_UPDATE_INTERVAL} +import org.apache.comet.CometConf._ import org.apache.comet.Tracing.withTrace +import org.apache.comet.parquet.CometFileKeyUnwrapper import org.apache.comet.serde.Config.ConfigMap import org.apache.comet.vector.NativeUtil @@ -52,6 +56,8 @@ import org.apache.comet.vector.NativeUtil * The number of partitions. * @param partitionIndex * The index of the partition. + * @param encryptedFilePaths + * Paths to encrypted Parquet files that need key unwrapping. */ class CometExecIterator( val id: Long, @@ -60,7 +66,8 @@ class CometExecIterator( protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode, numParts: Int, - partitionIndex: Int) + partitionIndex: Int, + encryptedFilePaths: Seq[(String, Broadcast[SerializableConfiguration])] = Seq.empty) extends Iterator[ColumnarBatch] with Logging { @@ -72,6 +79,7 @@ class CometExecIterator( private val cometBatchIterators = inputs.map { iterator => new CometBatchIterator(iterator, nativeUtil) }.toArray + private val plan = { val conf = SparkEnv.get.conf val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs @@ -93,6 +101,20 @@ class CometExecIterator( } val protobufSparkConfigs = builder.build().toByteArray + // Create keyUnwrapper if encryption is enabled + val keyUnwrapper = if (encryptedFilePaths.nonEmpty) { + val unwrapper = new CometFileKeyUnwrapper() + + encryptedFilePaths.foreach { case (filePath, broadcastedConf) => + val hadoopConf: Configuration = broadcastedConf.value.value + unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf) + } + + unwrapper + } else { + null + } + nativeLib.createPlan( id, cometBatchIterators, @@ -111,7 +133,8 @@ class CometExecIterator( taskAttemptId = TaskContext.get().taskAttemptId, debug = COMET_DEBUG_ENABLED.get(), explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), - tracingEnabled) + tracingEnabled, + keyUnwrapper) } private var nextBatch: Option[ColumnarBatch] = None @@ -135,6 +158,7 @@ class CometExecIterator( def convertToInt(threads: String): Int = { if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt } + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r val master = conf.get("spark.master") diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 7430a4322c..2337c34e55 100644 --- 
a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -24,6 +24,8 @@ import java.nio.ByteBuffer import org.apache.spark.CometTaskMemoryManager import org.apache.spark.sql.comet.CometMetricNode +import org.apache.comet.parquet.CometFileKeyUnwrapper + class Native extends NativeBase { // scalastyle:off @@ -68,7 +70,8 @@ class Native extends NativeBase { taskAttemptId: Long, debug: Boolean, explain: Boolean, - tracingEnabled: Boolean): Long + tracingEnabled: Boolean, + keyUnwrapper: CometFileKeyUnwrapper): Long // scalastyle:on /** diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index cbca7304d2..25a78cc329 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -46,12 +46,14 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanE import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet} +import org.apache.comet.parquet.CometFileKeyUnwrapper import org.apache.comet.shims.CometTypeShim /** * Spark physical optimizer rule for replacing Spark scans with Comet scans. */ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with CometTypeShim { + import CometScanRule._ private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get() @@ -144,22 +146,11 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com return withInfos(scanExec, fallbackReasons.toSet) } - val encryptionEnabled: Boolean = - conf.getConfString("parquet.crypto.factory.class", "").nonEmpty && - conf.getConfString("parquet.encryption.kms.client.class", "").nonEmpty - var scanImpl = COMET_NATIVE_SCAN_IMPL.get() // if scan is auto then pick the best available scan if (scanImpl == SCAN_AUTO) { - if (encryptionEnabled) { - logInfo( - s"Auto scan mode falling back to $SCAN_NATIVE_COMET because " + - s"$SCAN_NATIVE_ICEBERG_COMPAT does not support reading encrypted Parquet files") - scanImpl = SCAN_NATIVE_COMET - } else { - scanImpl = selectScan(scanExec, r.partitionSchema) - } + scanImpl = selectScan(scanExec, r.partitionSchema) } if (scanImpl == SCAN_NATIVE_DATAFUSION && !COMET_EXEC_ENABLED.get()) { @@ -206,12 +197,6 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com return withInfos(scanExec, fallbackReasons.toSet) } - if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled) { - fallbackReasons += - "Full native scan disabled because encryption is not supported" - return withInfos(scanExec, fallbackReasons.toSet) - } - val typeChecker = CometScanTypeChecker(scanImpl) val schemaSupported = typeChecker.isSchemaSupported(scanExec.requiredSchema, fallbackReasons) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 258d275e5b..4f7c03a0d4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1153,6 +1153,12 @@ object QueryPlanSerde extends Logging with CometExprShim { // Collect S3/cloud storage configurations val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) + val encryptionEnabled: Boolean = 
hadoopConf + .get("parquet.crypto.factory.class") + .nonEmpty || hadoopConf.get("parquet.encryption.kms.client.class", "").nonEmpty + + nativeScanBuilder.setEncryptionEnabled(encryptionEnabled) + firstPartition.foreach { partitionFile => val objectStoreOptions = NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 018c9f7c10..6a80410b3b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -101,7 +101,8 @@ class CometNativeShuffleWriter[K, V]( nativePlan, nativeMetrics, numParts, - context.partitionId()) + context.partitionId(), + encryptedFilePaths = Seq.empty) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a7cfacc475..2b425fa75b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,22 +25,23 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils -import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.PartitioningPreservingUnaryExecNode +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects @@ -114,7 +115,8 @@ object CometExec { nativePlan, CometMetricNode(Map.empty), numParts, - partitionIdx) + partitionIdx, + encryptedFilePaths = Seq.empty) } def getCometIterator( @@ -123,7 +125,9 @@ object CometExec { nativePlan: Operator, nativeMetrics: CometMetricNode, numParts: Int, - partitionIdx: Int): CometExecIterator = { + partitionIdx: Int, + encryptedFilePaths: Seq[(String, Broadcast[SerializableConfiguration])]) + : CometExecIterator = { val outputStream = new ByteArrayOutputStream() 
nativePlan.writeTo(outputStream) outputStream.close() @@ -135,7 +139,8 @@ object CometExec { bytes, nativeMetrics, numParts, - partitionIdx) + partitionIdx, + encryptedFilePaths) } /** @@ -201,6 +206,32 @@ abstract class CometNativeExec extends CometExec { // TODO: support native metrics for all operators. val nativeMetrics = CometMetricNode.fromCometPlan(this) + // For each relation in a CometNativeScan generate a hadoopConf, + // for each file path in a relation associate with hadoopConf + val cometNativeScans: Seq[CometNativeScanExec] = this + .collectLeaves() + .filter(_.isInstanceOf[CometNativeScanExec]) + .map(_.asInstanceOf[CometNativeScanExec]) + val encryptedFilePaths = cometNativeScans.flatMap { scan => + // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and + // per-relation configs since different tables might have different decryption + // properties. + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + val encryptionEnabled: Boolean = + hadoopConf.get("parquet.crypto.factory.class").nonEmpty && + hadoopConf.get("parquet.encryption.kms.client.class", "").nonEmpty + if (encryptionEnabled) { + // hadoopConf isn't serializable, so we have to do a broadcasted config. + val broadcastedConf = + scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + scan.relation.inputFiles.map { filePath => (filePath, broadcastedConf) } + } else { + Seq.empty + } + } + def createCometExecIter( inputs: Seq[Iterator[ColumnarBatch]], numParts: Int, @@ -212,7 +243,8 @@ abstract class CometNativeExec extends CometExec { serializedPlanCopy, nativeMetrics, numParts, - partitionIndex) + partitionIndex, + encryptedFilePaths) setSubqueries(it.id, this) @@ -429,6 +461,7 @@ abstract class CometBinaryExec extends CometNativeExec with BinaryExecNode */ case class SerializedPlan(plan: Option[Array[Byte]]) { def isDefined: Boolean = plan.isDefined + def isEmpty: Boolean = plan.isEmpty } @@ -442,6 +475,7 @@ case class CometProjectExec( extends CometUnaryExec with PartitioningPreservingUnaryExecNode { override def producedAttributes: AttributeSet = outputSet + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = this.copy(child = newChild) @@ -474,6 +508,7 @@ case class CometFilterExec( extends CometUnaryExec { override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -551,7 +586,9 @@ case class CometLocalLimitExec( extends CometUnaryExec { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -585,7 +622,9 @@ case class CometGlobalLimitExec( extends CometUnaryExec { override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = @@ -985,6 +1024,7 @@ case class CometScanWrapper(override val nativeOp: Operator, override val origin extends CometNativeExec with LeafExecNode { override val serializedPlanOpt: SerializedPlan = SerializedPlan(None) + override def stringArgs: 
Iterator[Any] = Iterator(originalPlan.output, originalPlan) } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 8d2c3db727..45b89735e5 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.comet.{CometConf, IntegrationTestSuite} +import org.apache.comet.CometConf.{SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT} /** * A integration test suite that tests parquet modular encryption usage. @@ -49,8 +50,6 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8)) test("SPARK-34990: Write and read an encrypted parquet") { - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT) import testImplicits._ @@ -93,8 +92,6 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } test("SPARK-37117: Can't read files in Parquet encryption external key material mode") { - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION) - assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_ICEBERG_COMPAT) import testImplicits._ @@ -146,13 +143,29 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: Position): Unit = { + Seq("true", "false").foreach { cometEnabled => - super.test(testName + s" Comet($cometEnabled)", testTags: _*) { - withSQLConf( - CometConf.COMET_ENABLED.key -> cometEnabled, - CometConf.COMET_EXEC_ENABLED.key -> "true", - SQLConf.ANSI_ENABLED.key -> "true") { - testFun + if (cometEnabled == "true") { + Seq(SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT).foreach { + scanImpl => + super.test(testName + s" Comet($cometEnabled)" + s"Scan($scanImpl)", testTags: _*) { + withSQLConf( + CometConf.COMET_ENABLED.key -> cometEnabled, + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> scanImpl) { + testFun + } + } + } + } else { + super.test(testName + s" Comet($cometEnabled)", testTags: _*) { + withSQLConf( + CometConf.COMET_ENABLED.key -> cometEnabled, + CometConf.COMET_EXEC_ENABLED.key -> "true", + SQLConf.ANSI_ENABLED.key -> "false") { + testFun + } } } } @@ -164,7 +177,9 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } private var _spark: SparkSessionType = _ + protected implicit override def spark: SparkSessionType = _spark + protected implicit override def sqlContext: SQLContext = _spark.sqlContext /** From 8bac76a1cc2744c8b227decf32e71ecd3ddb2468 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 23 Sep 2025 15:44:24 -0400 Subject: [PATCH 02/19] Fix unused import. 
--- spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 25a78cc329..e8f9e7cdb3 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -46,7 +46,6 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanE import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet} -import org.apache.comet.parquet.CometFileKeyUnwrapper import org.apache.comet.shims.CometTypeShim /** From 40935df19ee74785d8af82089db4a1df27868910 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 23 Sep 2025 18:06:25 -0400 Subject: [PATCH 03/19] Fix encryptionEnabled check in NativeBatchReader.java, and guard encryption factory registration in parquet_exec.rs. --- .../org/apache/comet/parquet/NativeBatchReader.java | 6 ++++-- native/core/src/parquet/parquet_exec.rs | 12 +++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 1897889359..d4de5f9bd6 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -411,8 +411,10 @@ public void init() throws Throwable { } boolean encryptionEnabled = - !conf.get("parquet.crypto.factory.class").isEmpty() - || !conf.get("parquet.encryption.kms.client.class", "").isEmpty(); + (conf.get("parquet.crypto.factory.class") != null + && !conf.get("parquet.crypto.factory.class").isEmpty()) + || (conf.get("parquet.encryption.kms.client.class") != null + && !conf.get("parquet.encryption.kms.client.class").isEmpty()); // Create keyUnwrapper if encryption is enabled CometFileKeyUnwrapper keyUnwrapper = encryptionEnabled ? new CometFileKeyUnwrapper() : null; diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs index 56f575dc2a..cdeefb3dfb 100644 --- a/native/core/src/parquet/parquet_exec.rs +++ b/native/core/src/parquet/parquet_exec.rs @@ -104,11 +104,13 @@ pub(crate) fn init_datasource_exec( } } - parquet_source = parquet_source.with_encryption_factory( - session_ctx - .runtime_env() - .parquet_encryption_factory(ENCRYPTION_FACTORY_ID)?, - ); + if encryption_enabled { + parquet_source = parquet_source.with_encryption_factory( + session_ctx + .runtime_env() + .parquet_encryption_factory(ENCRYPTION_FACTORY_ID)?, + ); + } let file_source = parquet_source.with_schema_adapter_factory(Arc::new( SparkSchemaAdapterFactory::new(spark_parquet_options, default_values), From 7cbfb1b8c35899da28f11d953e03ae8f0682debf Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 23 Sep 2025 18:27:13 -0400 Subject: [PATCH 04/19] Fix NPE when checking encryptedEnabled. 
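
Hadoop's Configuration.get(key) returns null when a key is unset, so the
earlier single-argument lookups could throw a NullPointerException as soon as
.nonEmpty was called on the result. A condensed sketch of the guarded check
applied in QueryPlanSerde and operators below (the wrapper function is
illustrative; the patch inlines the checks):

    import org.apache.hadoop.conf.Configuration

    // Guard against the null that Configuration.get returns for unset keys
    // before testing the value itself.
    def encryptionEnabled(hadoopConf: Configuration): Boolean = {
      val factoryClass = hadoopConf.get("parquet.crypto.factory.class")
      val kmsClass = hadoopConf.get("parquet.encryption.kms.client.class")
      (factoryClass != null && factoryClass.nonEmpty) ||
      (kmsClass != null && kmsClass.nonEmpty)
    }
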
--- .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 9 +++++++-- .../scala/org/apache/spark/sql/comet/operators.scala | 10 +++++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4f7c03a0d4..52549b31a6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1153,9 +1153,13 @@ object QueryPlanSerde extends Logging with CometExprShim { // Collect S3/cloud storage configurations val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled: Boolean = hadoopConf + val encryptionEnabled: Boolean = (hadoopConf + .get("parquet.crypto.factory.class") != null && hadoopConf .get("parquet.crypto.factory.class") - .nonEmpty || hadoopConf.get("parquet.encryption.kms.client.class", "").nonEmpty + .nonEmpty) || (hadoopConf + .get("parquet.encryption.kms.client.class") != null && hadoopConf + .get("parquet.encryption.kms.client.class") + .nonEmpty) nativeScanBuilder.setEncryptionEnabled(encryptionEnabled) @@ -1700,6 +1704,7 @@ object QueryPlanSerde extends Logging with CometExprShim { } // scalastyle:off + /** * Align w/ Arrow's * [[https://github.com/apache/arrow-rs/blob/55.2.0/arrow-ord/src/rank.rs#L30-L40 can_rank]] and diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2b425fa75b..8e0fee594a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -218,9 +218,13 @@ abstract class CometNativeExec extends CometExec { // properties. val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled: Boolean = - hadoopConf.get("parquet.crypto.factory.class").nonEmpty && - hadoopConf.get("parquet.encryption.kms.client.class", "").nonEmpty + val encryptionEnabled: Boolean = (hadoopConf + .get("parquet.crypto.factory.class") != null && hadoopConf + .get("parquet.crypto.factory.class") + .nonEmpty) || (hadoopConf + .get("parquet.encryption.kms.client.class") != null && hadoopConf + .get("parquet.encryption.kms.client.class") + .nonEmpty) if (encryptionEnabled) { // hadoopConf isn't serializable, so we have to do a broadcasted config. val broadcastedConf = From 090497b787b6854af88344e212f500196140f027 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 24 Sep 2025 15:15:40 -0400 Subject: [PATCH 05/19] Minor refactor for encryptionEnabled. 
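
The duplicated property checks move into a single helper in CometParquetUtils
so that NativeBatchReader, QueryPlanSerde, and CometNativeExec all share one
definition. A minimal sketch of its shape (the real version is in the diff
below and uses the parquet-hadoop property-name constants):

    import org.apache.hadoop.conf.Configuration
    import org.apache.parquet.crypto.DecryptionPropertiesFactory
    import org.apache.parquet.crypto.keytools.KeyToolkit

    // Option(...) absorbs the null returned for unset keys, so a single
    // exists(...) covers both the crypto factory and the KMS client class.
    def encryptionEnabled(hadoopConf: Configuration): Boolean =
      Seq(
        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
        KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME)
        .exists(key => Option(hadoopConf.get(key)).exists(_.nonEmpty))
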
--- .../org/apache/comet/parquet/NativeBatchReader.java | 6 +----- .../org/apache/comet/parquet/CometParquetUtils.scala | 11 +++++++++++ .../org/apache/comet/serde/QueryPlanSerde.scala | 12 +++--------- .../scala/org/apache/spark/sql/comet/operators.scala | 9 ++------- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index d4de5f9bd6..1b86b6e3a5 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -410,11 +410,7 @@ public void init() throws Throwable { } } - boolean encryptionEnabled = - (conf.get("parquet.crypto.factory.class") != null - && !conf.get("parquet.crypto.factory.class").isEmpty()) - || (conf.get("parquet.encryption.kms.client.class") != null - && !conf.get("parquet.encryption.kms.client.class").isEmpty()); + boolean encryptionEnabled = CometParquetUtils.encryptionEnabled(conf); // Create keyUnwrapper if encryption is enabled CometFileKeyUnwrapper keyUnwrapper = encryptionEnabled ? new CometFileKeyUnwrapper() : null; diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala index a37ec7e66a..d245b9734b 100644 --- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala +++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala @@ -20,6 +20,8 @@ package org.apache.comet.parquet import org.apache.hadoop.conf.Configuration +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.KeyToolkit import org.apache.spark.sql.internal.SQLConf object CometParquetUtils { @@ -38,4 +40,13 @@ object CometParquetUtils { def ignoreMissingIds(conf: SQLConf): Boolean = conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean + + def encryptionEnabled(hadoopConf: Configuration): Boolean = { + // TODO: Are there any other properties to check? 
+ val encryptionKeys = Seq( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME) + + encryptionKeys.exists(key => Option(hadoopConf.get(key)).exists(_.nonEmpty)) + } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 52549b31a6..dde35a4a70 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -48,6 +48,7 @@ import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ import org.apache.comet.objectstore.NativeConfig +import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.ExprOuterClass.{AggExpr, Expr, ScalarFunc} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto} @@ -1153,15 +1154,8 @@ object QueryPlanSerde extends Logging with CometExprShim { // Collect S3/cloud storage configurations val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled: Boolean = (hadoopConf - .get("parquet.crypto.factory.class") != null && hadoopConf - .get("parquet.crypto.factory.class") - .nonEmpty) || (hadoopConf - .get("parquet.encryption.kms.client.class") != null && hadoopConf - .get("parquet.encryption.kms.client.class") - .nonEmpty) - - nativeScanBuilder.setEncryptionEnabled(encryptionEnabled) + + nativeScanBuilder.setEncryptionEnabled(CometParquetUtils.encryptionEnabled(hadoopConf)) firstPartition.foreach { partitionFile => val objectStoreOptions = diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 8e0fee594a..b4709cecdf 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -47,6 +47,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException} +import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.OperatorOuterClass.Operator /** @@ -218,13 +219,7 @@ abstract class CometNativeExec extends CometExec { // properties. val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled: Boolean = (hadoopConf - .get("parquet.crypto.factory.class") != null && hadoopConf - .get("parquet.crypto.factory.class") - .nonEmpty) || (hadoopConf - .get("parquet.encryption.kms.client.class") != null && hadoopConf - .get("parquet.encryption.kms.client.class") - .nonEmpty) + val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) if (encryptionEnabled) { // hadoopConf isn't serializable, so we have to do a broadcasted config. val broadcastedConf = From c9dfdd5df3d59ef3dd8aeb2ed8bd479ce7692568 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 26 Sep 2025 15:22:03 -0400 Subject: [PATCH 06/19] More tests. 
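
The new cases cover plaintext footer mode, the AES_GCM_CTR_V1 algorithm and
256-bit key lengths (where native_datafusion and native_iceberg_compat fall
back to Spark), double wrapping disabled, and a join between files encrypted
with different column keys. All of them share the same InMemoryKMS setup,
roughly (condensed from the suite below):

    withSQLConf(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME ->
        "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory",
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
        "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
      InMemoryKMS.KEY_LIST_PROPERTY_NAME ->
        s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") {
      // write with per-column keys, read back, and compare against Spark
    }
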
--- .../comet/parquet/CometParquetUtils.scala | 48 +++ .../apache/comet/rules/CometScanRule.scala | 22 +- .../sql/comet/ParquetEncryptionITCase.scala | 300 +++++++++++++++++- 3 files changed, 350 insertions(+), 20 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala index d245b9734b..56d9a7dcfd 100644 --- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala +++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala @@ -29,6 +29,21 @@ object CometParquetUtils { private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled" private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing" + // Map of unsupported encryption configuration key-value pairs + private val UNSUPPORTED_ENCRYPTION_CONFIGS: Map[String, Set[String]] = Map( + "parquet.encryption.algorithm" -> Set("AES_GCM_CTR_V1") + // Add more unsupported configs here as needed + // "parquet.encryption.some.config" -> Set("unsupported_value1", "unsupported_value2") + ) + + // Map of encryption configurations that can only have specific allowed values + private val SUPPORTED_ENCRYPTION_CONFIGS_WHITELIST: Map[String, Set[String]] = Map( + "parquet.encryption.data.key.length.bits" -> Set("128"), + "parquet.encryption.kek.length.bits" -> Set("128") + // Add more whitelisted configs here as needed + // "parquet.encryption.some.config" -> Set("allowed_value1", "allowed_value2") + ) + def writeFieldId(conf: SQLConf): Boolean = conf.getConfString(PARQUET_FIELD_ID_WRITE_ENABLED, "false").toBoolean @@ -41,6 +56,39 @@ object CometParquetUtils { def ignoreMissingIds(conf: SQLConf): Boolean = conf.getConfString(IGNORE_MISSING_PARQUET_FIELD_ID, "false").toBoolean + /** + * Checks if the given Hadoop configuration contains any unsupported encryption settings. + * + * @param hadoopConf + * The Hadoop configuration to check + * @return + * true if all encryption configurations are supported, false if any unsupported config is + * found + */ + def isEncryptionConfigSupported(hadoopConf: Configuration): Boolean = { + // Check blacklist: configurations that should never have certain values + val blacklistCheck = UNSUPPORTED_ENCRYPTION_CONFIGS.forall { + case (configKey, unsupportedValues) => + val configValue = Option(hadoopConf.get(configKey)) + configValue match { + case Some(value) => !unsupportedValues.contains(value) + case None => true // Config not set, so it's supported + } + } + + // Check whitelist: configurations that can only have specific allowed values + val whitelistCheck = SUPPORTED_ENCRYPTION_CONFIGS_WHITELIST.forall { + case (configKey, allowedValues) => + val configValue = Option(hadoopConf.get(configKey)) + configValue match { + case Some(value) => allowedValues.contains(value) + case None => true // Config not set, so it's supported + } + } + + blacklistCheck && whitelistCheck + } + def encryptionEnabled(hadoopConf: Configuration): Boolean = { // TODO: Are there any other properties to check? 
val encryptionKeys = Seq( diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index e8f9e7cdb3..950d0e9d37 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -46,6 +46,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanE import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet} +import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled, isEncryptionConfigSupported} import org.apache.comet.shims.CometTypeShim /** @@ -147,9 +148,12 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com var scanImpl = COMET_NATIVE_SCAN_IMPL.get() + val hadoopConf = scanExec.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scanExec.relation.options) + // if scan is auto then pick the best available scan if (scanImpl == SCAN_AUTO) { - scanImpl = selectScan(scanExec, r.partitionSchema) + scanImpl = selectScan(scanExec, r.partitionSchema, hadoopConf) } if (scanImpl == SCAN_NATIVE_DATAFUSION && !COMET_EXEC_ENABLED.get()) { @@ -196,6 +200,12 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com return withInfos(scanExec, fallbackReasons.toSet) } + if (scanImpl != CometConf.SCAN_NATIVE_COMET && encryptionEnabled(hadoopConf)) { + if (!isEncryptionConfigSupported(hadoopConf)) { + return withInfos(scanExec, fallbackReasons.toSet) + } + } + val typeChecker = CometScanTypeChecker(scanImpl) val schemaSupported = typeChecker.isSchemaSupported(scanExec.requiredSchema, fallbackReasons) @@ -287,7 +297,10 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com } } - private def selectScan(scanExec: FileSourceScanExec, partitionSchema: StructType): String = { + private def selectScan( + scanExec: FileSourceScanExec, + partitionSchema: StructType, + hadoopConf: Configuration): String = { val fallbackReasons = new ListBuffer[String]() @@ -297,10 +310,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com val filePath = scanExec.relation.inputFiles.headOption if (filePath.exists(_.startsWith("s3a://"))) { - validateObjectStoreConfig( - filePath.get, - session.sparkContext.hadoopConfiguration, - fallbackReasons) + validateObjectStoreConfig(filePath.get, hadoopConf, fallbackReasons) } } else { fallbackReasons += s"$SCAN_NATIVE_ICEBERG_COMPAT only supports local filesystem and S3" diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 45b89735e5..8503fbc66f 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.comet -import java.io.File -import java.io.RandomAccessFile +import java.io.{File, RandomAccessFile} import java.nio.charset.StandardCharsets import java.util.Base64 @@ -29,6 +28,9 @@ import org.scalactic.source.Position import org.scalatest.Tag import org.scalatestplus.junit.JUnitRunner +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} +import 
org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{CometTestBase, SQLContext} import org.apache.spark.sql.internal.SQLConf @@ -57,10 +59,10 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { factoryClass => withTempDir { dir => withSQLConf( - "parquet.crypto.factory.class" -> factoryClass, - "parquet.encryption.kms.client.class" -> + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - "parquet.encryption.key.list" -> + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { // Make sure encryption works with multiple Parquet files @@ -71,8 +73,10 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { .toDF("a", "b", "c") val parquetDir = new File(dir, "parquet").getCanonicalPath inputDF.write - .option("parquet.encryption.column.keys", "key1: a, b; key2: c") - .option("parquet.encryption.footer.key", "footerKey") + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") .parquet(parquetDir) verifyParquetEncrypted(parquetDir) @@ -99,11 +103,11 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { factoryClass => withTempDir { dir => withSQLConf( - "parquet.crypto.factory.class" -> factoryClass, - "parquet.encryption.kms.client.class" -> + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - "parquet.encryption.key.material.store.internally" -> "false", - "parquet.encryption.key.list" -> + KeyToolkit.KEY_MATERIAL_INTERNAL_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { val inputDF = spark @@ -113,8 +117,10 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { .toDF("a", "b", "c") val parquetDir = new File(dir, "parquet").getCanonicalPath inputDF.write - .option("parquet.encryption.column.keys", "key1: a, b; key2: c") - .option("parquet.encryption.footer.key", "footerKey") + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") .parquet(parquetDir) val parquetDF = spark.read.parquet(parquetDir) @@ -131,6 +137,248 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } } + test("Plain text footer mode") { + import testImplicits._ + + Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { + factoryClass => + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + PropertiesDrivenCryptoFactory.PLAINTEXT_FOOTER_PROPERTY_NAME -> "true", // default is false + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option( + 
PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetPlaintextFooter(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + } + + test("Change encryption algorithm") { + import testImplicits._ + + Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { + factoryClass => + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + // default is AES_GCM_V1 + PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> "AES_GCM_CTR_V1", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs + // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L414 + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + } + + test("Test double wrapping disabled") { + import testImplicits._ + + Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { + factoryClass => + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DOUBLE_WRAPPING_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + } + + test("Join between files with different encryption keys") { + import testImplicits._ + + 
Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { + factoryClass => + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Write first file + val inputDF1 = spark + .range(0, 100) + .map(i => (i, s"file1_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir1 = new File(dir, "parquet1").getCanonicalPath + inputDF1.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir1) + + // Write second file using different column key + val inputDF2 = spark + .range(0, 100) + .map(i => (i, s"file2_${i}", (i * 2).toFloat)) + .toDF("id", "description", "score") + val parquetDir2 = new File(dir, "parquet2").getCanonicalPath + inputDF2.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key2: id, description, score") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir2) + + // Now perform a join between the two files with different encryption keys + // This tests that hadoopConf properties propagate correctly to each scan + val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1") + val parquetDF2 = spark.read.parquet(parquetDir2).alias("f2") + + val joinedDF = parquetDF1 + .join(parquetDF2, parquetDF1("id") === parquetDF2("id"), "inner") + .select( + parquetDF1("id"), + parquetDF1("name"), + parquetDF2("description"), + parquetDF2("score")) + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(joinedDF) + } else { + checkSparkAnswer(joinedDF) + } + } + } + } + } + + test("Test different key lengths") { + import testImplicits._ + + Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { + factoryClass => + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> "256", + KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> "256", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs not + // supporting other key lengths + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + } + protected override def sparkConf: SparkConf = { val conf = new SparkConf() 
conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) @@ -148,7 +396,7 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { if (cometEnabled == "true") { Seq(SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT).foreach { scanImpl => - super.test(testName + s" Comet($cometEnabled)" + s"Scan($scanImpl)", testTags: _*) { + super.test(testName + s" Comet($cometEnabled)" + s" Scan($scanImpl)", testTags: _*) { withSQLConf( CometConf.COMET_ENABLED.key -> cometEnabled, CometConf.COMET_EXEC_ENABLED.key -> "true", @@ -206,6 +454,30 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } } + /** + * Verify that the directory contains an encrypted parquet in plaintext footer mode by means of + * checking for all the parquet part files in the parquet directory that their magic string is + * PAR1, as defined in the spec: + * https://github.com/apache/parquet-format/blob/master/Encryption.md#55-plaintext-footer-mode + */ + private def verifyParquetPlaintextFooter(parquetDir: String): Unit = { + val parquetPartitionFiles = getListOfParquetFiles(new File(parquetDir)) + assert(parquetPartitionFiles.size >= 1) + parquetPartitionFiles.foreach { parquetFile => + val magicString = "PAR1" + val magicStringLength = magicString.length() + val byteArray = new Array[Byte](magicStringLength) + val randomAccessFile = new RandomAccessFile(parquetFile, "r") + try { + randomAccessFile.read(byteArray, 0, magicStringLength) + } finally { + randomAccessFile.close() + } + val stringRead = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == stringRead) + } + } + private def getListOfParquetFiles(dir: File): List[File] = { dir.listFiles.filter(_.isFile).toList.filter { file => file.getName.endsWith("parquet") From bf0bec42310cddf838eb9e299c82e313f41c6555 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 26 Sep 2025 15:37:53 -0400 Subject: [PATCH 07/19] Cleanup Seq loop that wasn't doing anything. 
--- .../sql/comet/ParquetEncryptionITCase.scala | 523 ++++++++---------- 1 file changed, 246 insertions(+), 277 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 8503fbc66f..b7875d3100 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -50,48 +50,45 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) private val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) private val key2 = encoder.encodeToString("1234567890123451".getBytes(StandardCharsets.UTF_8)) + private val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" test("SPARK-34990: Write and read an encrypted parquet") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - // Make sure encryption works with multiple Parquet files - val inputDF = spark - .range(0, 2000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(10) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - verifyParquetEncrypted(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Make sure encryption works with multiple Parquet files + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } @@ -99,283 +96,255 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { import testImplicits._ - 
Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - KeyToolkit.KEY_MATERIAL_INTERNAL_PROPERTY_NAME -> "false", // default is true - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 2000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(10) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.KEY_MATERIAL_INTERNAL_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } test("Plain text footer mode") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - PropertiesDrivenCryptoFactory.PLAINTEXT_FOOTER_PROPERTY_NAME -> "true", // default is false - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 1000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(5) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - verifyParquetPlaintextFooter(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - 
checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + PropertiesDrivenCryptoFactory.PLAINTEXT_FOOTER_PROPERTY_NAME -> "true", // default is false + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetPlaintextFooter(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } test("Change encryption algorithm") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - // default is AES_GCM_V1 - PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> "AES_GCM_CTR_V1", - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 1000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(5) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - verifyParquetEncrypted(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - // native_datafusion and native_iceberg_compat fall back due to Arrow-rs - // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L414 - if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( - conf) == SCAN_NATIVE_COMET) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + // default is AES_GCM_V1 + PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> "AES_GCM_CTR_V1", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + 
.option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs + // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L414 + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } test("Test double wrapping disabled") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - KeyToolkit.DOUBLE_WRAPPING_PROPERTY_NAME -> "false", // default is true - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 1000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(5) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - verifyParquetEncrypted(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DOUBLE_WRAPPING_PROPERTY_NAME -> "false", // default is true + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } test("Join between files with different encryption keys") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: 
${footerKey}, key1: ${key1}, key2: ${key2}") { - - // Write first file - val inputDF1 = spark - .range(0, 100) - .map(i => (i, s"file1_${i}", i.toFloat)) - .toDF("id", "name", "value") - val parquetDir1 = new File(dir, "parquet1").getCanonicalPath - inputDF1.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: id, name, value") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir1) - - // Write second file using different column key - val inputDF2 = spark - .range(0, 100) - .map(i => (i, s"file2_${i}", (i * 2).toFloat)) - .toDF("id", "description", "score") - val parquetDir2 = new File(dir, "parquet2").getCanonicalPath - inputDF2.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key2: id, description, score") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir2) - - // Now perform a join between the two files with different encryption keys - // This tests that hadoopConf properties propagate correctly to each scan - val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1") - val parquetDF2 = spark.read.parquet(parquetDir2).alias("f2") - - val joinedDF = parquetDF1 - .join(parquetDF2, parquetDF1("id") === parquetDF2("id"), "inner") - .select( - parquetDF1("id"), - parquetDF1("name"), - parquetDF2("description"), - parquetDF2("score")) - - if (CometConf.COMET_ENABLED.get(conf)) { - checkSparkAnswerAndOperator(joinedDF) - } else { - checkSparkAnswer(joinedDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Write first file + val inputDF1 = spark + .range(0, 100) + .map(i => (i, s"file1_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir1 = new File(dir, "parquet1").getCanonicalPath + inputDF1.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir1) + + // Write second file using different column key + val inputDF2 = spark + .range(0, 100) + .map(i => (i, s"file2_${i}", (i * 2).toFloat)) + .toDF("id", "description", "score") + val parquetDir2 = new File(dir, "parquet2").getCanonicalPath + inputDF2.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key2: id, description, score") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir2) + + // Now perform a join between the two files with different encryption keys + // This tests that hadoopConf properties propagate correctly to each scan + val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1") + val parquetDF2 = spark.read.parquet(parquetDir2).alias("f2") + + val joinedDF = parquetDF1 + .join(parquetDF2, parquetDF1("id") === parquetDF2("id"), "inner") + .select( + parquetDF1("id"), + parquetDF1("name"), + parquetDF2("description"), + parquetDF2("score")) + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(joinedDF) + } else { + checkSparkAnswer(joinedDF) } + } } } test("Test different key lengths") { import testImplicits._ - Seq("org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory").foreach { - factoryClass => - 
withTempDir { dir => - withSQLConf( - DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> factoryClass, - KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> - "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", - KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> "256", - KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> "256", - InMemoryKMS.KEY_LIST_PROPERTY_NAME -> - s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { - - val inputDF = spark - .range(0, 1000) - .map(i => (i, i.toString, i.toFloat)) - .repartition(5) - .toDF("a", "b", "c") - val parquetDir = new File(dir, "parquet").getCanonicalPath - inputDF.write - .option( - PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, - "key1: a, b; key2: c") - .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") - .parquet(parquetDir) - - verifyParquetEncrypted(parquetDir) - - val parquetDF = spark.read.parquet(parquetDir) - assert(parquetDF.inputFiles.nonEmpty) - val readDataset = parquetDF.select("a", "b", "c") - - // native_datafusion and native_iceberg_compat fall back due to Arrow-rs not - // supporting other key lengths - if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( - conf) == SCAN_NATIVE_COMET) { - checkSparkAnswerAndOperator(readDataset) - } else { - checkAnswer(readDataset, inputDF) - } - } + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> "256", + KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> "256", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + val inputDF = spark + .range(0, 1000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(5) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: a, b; key2: c") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + // native_datafusion and native_iceberg_compat fall back due to Arrow-rs not + // supporting other key lengths + if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( + conf) == SCAN_NATIVE_COMET) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) } + } } } From 271e940366762d3f2c928fb732006badecd9c6a1 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 26 Sep 2025 15:43:08 -0400 Subject: [PATCH 08/19] Docs. --- .../org/apache/comet/parquet/CometFileKeyUnwrapper.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java index ea4037dbf0..73025457ef 100644 --- a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java +++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java @@ -30,8 +30,9 @@ /** * Helper class to access DecryptionKeyRetriever.getKey from native code via JNI. 
This class handles - * the complexity of getting the proper Hadoop Configuration from the current Spark context and - * creating properly configured DecryptionKeyRetriever instances using DecryptionPropertiesFactory. + * the complexity of creating and caching properly configured DecryptionKeyRetriever instances using + * DecryptionPropertiesFactory. The life of this object is meant to map to a single Comet plan, so + * associated with CometExecIterator. */ public class CometFileKeyUnwrapper { @@ -40,7 +41,7 @@ public class CometFileKeyUnwrapper { new ConcurrentHashMap<>(); // Each hadoopConf yields a unique DecryptionPropertiesFactory. While it's unlikely that - // this plan contains more than one hadoopConf, we don't want to assume that. So we'll + // this Comet plan contains more than one hadoopConf, we don't want to assume that. So we'll // provide the ability to cache more than one Factory with a map. private final ConcurrentHashMap factoryCache = new ConcurrentHashMap<>(); From 571c8815a8da85819667555ba3b9b8571b78ee0d Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 26 Sep 2025 15:59:51 -0400 Subject: [PATCH 09/19] Docs. --- .../comet/parquet/CometFileKeyUnwrapper.java | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java index 73025457ef..305cce414f 100644 --- a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java +++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java @@ -28,6 +28,62 @@ import org.apache.parquet.crypto.FileDecryptionProperties; import org.apache.parquet.crypto.ParquetCryptoRuntimeException; +// spotless:off +/* + * Architecture Overview: + * + * JVM Side | Native Side + * ┌─────────────────────────────────────┐ | ┌─────────────────────────────────────┐ + * │ CometFileKeyUnwrapper │ | │ Parquet File Reading │ + * │ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ ┌─────────────────────────────┐ │ + * │ │ hadoopConf │ │ | │ │ file1.parquet │ │ + * │ │ (Configuration) │ │ | │ │ file2.parquet │ │ + * │ └─────────────────────────────┘ │ | │ │ file3.parquet │ │ + * │ │ │ | │ └─────────────────────────────┘ │ + * │ ▼ │ | │ │ │ + * │ ┌─────────────────────────────┐ │ | │ │ │ + * │ │ factoryCache │ │ | │ ▼ │ + * │ │ (many-to-one mapping) │ │ | │ ┌─────────────────────────────┐ │ + * │ │ │ │ | │ │ Parse file metadata & │ │ + * │ │ file1 ──┐ │ │ | │ │ extract keyMetadata │ │ + * │ │ file2 ──┼─► DecryptionProps │ │ | │ └─────────────────────────────┘ │ + * │ │ file3 ──┘ Factory │ │ | │ │ │ + * │ └─────────────────────────────┘ │ | │ │ │ + * │ │ │ | │ ▼ │ + * │ ▼ │ | │ ╔═════════════════════════════╗ │ + * │ ┌─────────────────────────────┐ │ | │ ║ JNI CALL: ║ │ + * │ │ retrieverCache │ │ | │ ║ getKey(filePath, ║ │ + * │ │ filePath -> KeyRetriever │◄───┼───┼───┼──║ keyMetadata) ║ │ + * │ └─────────────────────────────┘ │ | │ ╚═════════════════════════════╝ │ + * │ │ │ | │ │ + * │ ▼ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ │ + * │ │ DecryptionKeyRetriever │ │ | │ │ + * │ │ .getKey(keyMetadata) │ │ | │ │ + * │ └─────────────────────────────┘ │ | │ │ + * │ │ │ | │ │ + * │ ▼ │ | │ │ + * │ ┌─────────────────────────────┐ │ | │ ┌─────────────────────────────┐ │ + * │ │ return key bytes │────┼───┼───┼─►│ Use key for decryption │ │ + * │ └─────────────────────────────┘ │ | │ │ of parquet data │ │ + * └─────────────────────────────────────┘ | │ 
└─────────────────────────────┘ │ + * | └─────────────────────────────────────┘ + * | + * JNI Boundary + * + * Setup Phase (storeDecryptionKeyRetriever): + * 1. hadoopConf → DecryptionPropertiesFactory (cached in factoryCache) + * 2. Factory + filePath → DecryptionKeyRetriever (cached in retrieverCache) + * + * Runtime Phase (getKey): + * 3. Native code calls getKey(filePath, keyMetadata) ──► JVM + * 4. Retrieve cached DecryptionKeyRetriever for filePath + * 5. KeyRetriever.getKey(keyMetadata) → decrypted key bytes + * 6. Return key bytes ──► Native code for parquet decryption + */ +// spotless:on + /** * Helper class to access DecryptionKeyRetriever.getKey from native code via JNI. This class handles * the complexity of creating and caching properly configured DecryptionKeyRetriever instances using From 4dde7fbee1c7cf1eaad3bc6e94f460b617b328b2 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 26 Sep 2025 16:30:54 -0400 Subject: [PATCH 10/19] Refactor out of parquet_exec.rs. --- native/core/src/execution/jni_api.rs | 2 +- native/core/src/parquet/encryption_support.rs | 151 ++++++++++++++++++ native/core/src/parquet/mod.rs | 6 +- native/core/src/parquet/parquet_exec.rs | 132 +-------------- 4 files changed, 157 insertions(+), 134 deletions(-) create mode 100644 native/core/src/parquet/encryption_support.rs diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index a120e8a29a..b722a96087 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -75,7 +75,7 @@ use crate::execution::spark_plan::SparkPlan; use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_trace}; -use crate::parquet::parquet_exec::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; +use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use datafusion_comet_proto::spark_operator::operator::OpStruct; use log::info; use once_cell::sync::Lazy; diff --git a/native/core/src/parquet/encryption_support.rs b/native/core/src/parquet/encryption_support.rs new file mode 100644 index 0000000000..1297d67ef1 --- /dev/null +++ b/native/core/src/parquet/encryption_support.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::execution::operators::ExecutionError; +use crate::jvm_bridge::JVMClasses; +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion::common::extensions_options; +use datafusion::config::EncryptionFactoryOptions; +use datafusion::error::DataFusionError; +use datafusion::execution::parquet_encryption::EncryptionFactory; +use jni::objects::{GlobalRef, JMethodID}; +use object_store::path::Path; +use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; +use parquet::encryption::encrypt::FileEncryptionProperties; +use std::sync::Arc; + +pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption"; + +// Options used to configure our example encryption factory +extensions_options! { + pub struct CometEncryptionConfig { + pub url_base: String, default = "file:///".into() + } +} + +#[derive(Debug)] +pub struct CometEncryptionFactory { + pub(crate) key_unwrapper: GlobalRef, +} + +/// `EncryptionFactory` is a DataFusion trait for types that generate +/// file encryption and decryption properties. +#[async_trait] +impl EncryptionFactory for CometEncryptionFactory { + async fn get_file_encryption_properties( + &self, + _options: &EncryptionFactoryOptions, + _schema: &SchemaRef, + _file_path: &Path, + ) -> Result, DataFusionError> { + Err(DataFusionError::NotImplemented( + "Comet does not support Parquet encryption yet." + .parse() + .unwrap(), + )) + } + + /// Generate file decryption properties to use when reading a Parquet file. + /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` + /// that can determine the keys using the encryption metadata. + async fn get_file_decryption_properties( + &self, + options: &EncryptionFactoryOptions, + file_path: &Path, + ) -> Result, DataFusionError> { + let config: CometEncryptionConfig = options.to_extension_options()?; + + let full_path: String = config.url_base + file_path.as_ref(); + let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + let decryption_properties = + FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?; + Ok(Some(decryption_properties)) + } +} + +pub struct CometKeyRetriever { + file_path: String, + key_unwrapper: GlobalRef, + get_key_method_id: JMethodID, +} + +impl CometKeyRetriever { + pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { + // Get JNI environment + let mut env = JVMClasses::get_env()?; + + Ok(CometKeyRetriever { + file_path: file_path.to_string(), + key_unwrapper, + get_key_method_id: env + .get_method_id( + "org/apache/comet/parquet/CometFileKeyUnwrapper", + "getKey", + "(Ljava/lang/String;[B)[B", + ) + .unwrap(), + }) + } +} + +impl KeyRetriever for CometKeyRetriever { + /// Get a data encryption key using the metadata stored in the Parquet file. 
+ fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result> { + use jni::{objects::JObject, signature::ReturnType}; + + // Get JNI environment + let mut env = JVMClasses::get_env() + .map_err(|e| datafusion::parquet::errors::ParquetError::General(e.to_string()))?; + + // Get the key unwrapper instance from GlobalRef + let unwrapper_instance = self.key_unwrapper.as_obj(); + + let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; + + // Convert file path to JString + let file_path_jstring = env.new_string(&self.file_path).unwrap(); + + // Convert key_metadata to JByteArray + let key_metadata_array = env.byte_array_from_slice(key_metadata).unwrap(); + + // Call instance method FileKeyUnwrapper.getKey(String, byte[]) -> byte[] + let result = unsafe { + env.call_method_unchecked( + instance, + self.get_key_method_id, + ReturnType::Array, + &[ + jni::objects::JValue::from(&file_path_jstring).as_jni(), + jni::objects::JValue::from(&key_metadata_array).as_jni(), + ], + ) + }; + + let result = result.unwrap(); + + // Extract the byte array from the result + let result_array = result.l().unwrap(); + + // Convert JObject to JByteArray and then to Vec + let byte_array: jni::objects::JByteArray = result_array.into(); + + let result_vec = env.convert_byte_array(&byte_array).unwrap(); + Ok(result_vec) + } +} diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index 876ab488bb..ca70c2fc34 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -16,6 +16,7 @@ // under the License. pub mod data_type; +pub mod encryption_support; pub mod mutable_vector; pub use mutable_vector::*; @@ -54,9 +55,8 @@ use crate::execution::serde; use crate::execution::utils::SparkArrowConvert; use crate::jvm_bridge::{jni_new_global_ref, JVMClasses}; use crate::parquet::data_type::AsBytes; -use crate::parquet::parquet_exec::{ - init_datasource_exec, CometEncryptionFactory, ENCRYPTION_FACTORY_ID, -}; +use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; +use crate::parquet::parquet_exec::init_datasource_exec; use crate::parquet::parquet_support::prepare_object_store_with_configs; use arrow::array::{Array, RecordBatch}; use arrow::buffer::{Buffer, MutableBuffer}; diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs index cdeefb3dfb..dd73bcee13 100644 --- a/native/core/src/parquet/parquet_exec.rs +++ b/native/core/src/parquet/parquet_exec.rs @@ -16,31 +16,23 @@ // under the License. 
use crate::execution::operators::ExecutionError; -use crate::jvm_bridge::JVMClasses; +use crate::parquet::encryption_support::{CometEncryptionConfig, ENCRYPTION_FACTORY_ID}; use crate::parquet::parquet_support::SparkParquetOptions; use crate::parquet::schema_adapter::SparkSchemaAdapterFactory; use arrow::datatypes::{Field, SchemaRef}; -use async_trait::async_trait; -use datafusion::common::extensions_options; -use datafusion::config::{EncryptionFactoryOptions, TableParquetOptions}; +use datafusion::config::TableParquetOptions; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{ FileGroup, FileScanConfigBuilder, FileSource, ParquetSource, }; use datafusion::datasource::source::DataSourceExec; -use datafusion::error::DataFusionError; use datafusion::execution::object_store::ObjectStoreUrl; -use datafusion::execution::parquet_encryption::EncryptionFactory; use datafusion::physical_expr::expressions::BinaryExpr; use datafusion::physical_expr::PhysicalExpr; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_comet_spark_expr::EvalMode; use itertools::Itertools; -use jni::objects::{GlobalRef, JMethodID}; -use object_store::path::Path; -use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; -use parquet::encryption::encrypt::FileEncryptionProperties; use std::collections::HashMap; use std::sync::Arc; @@ -147,126 +139,6 @@ pub(crate) fn init_datasource_exec( Ok(Arc::new(DataSourceExec::new(Arc::new(file_scan_config)))) } -pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption"; - -// Options used to configure our example encryption factory -extensions_options! { - struct CometEncryptionConfig { - url_base: String, default = "file:///".into() - } -} -#[derive(Debug)] -pub struct CometEncryptionFactory { - pub(crate) key_unwrapper: GlobalRef, -} - -/// `EncryptionFactory` is a DataFusion trait for types that generate -/// file encryption and decryption properties. -#[async_trait] -impl EncryptionFactory for CometEncryptionFactory { - async fn get_file_encryption_properties( - &self, - _options: &EncryptionFactoryOptions, - _schema: &SchemaRef, - _file_path: &Path, - ) -> Result, DataFusionError> { - Err(DataFusionError::NotImplemented( - "Comet does not support Parquet encryption yet." - .parse() - .unwrap(), - )) - } - - /// Generate file decryption properties to use when reading a Parquet file. - /// Rather than provide the AES keys directly for decryption, we set a `KeyRetriever` - /// that can determine the keys using the encryption metadata. 
- async fn get_file_decryption_properties( - &self, - options: &EncryptionFactoryOptions, - file_path: &Path, - ) -> Result, DataFusionError> { - let config: CometEncryptionConfig = options.to_extension_options()?; - - let full_path: String = config.url_base + file_path.as_ref(); - let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) - .map_err(|e| DataFusionError::External(Box::new(e)))?; - let decryption_properties = - FileDecryptionProperties::with_key_retriever(Arc::new(key_retriever)).build()?; - Ok(Some(decryption_properties)) - } -} - -struct CometKeyRetriever { - file_path: String, - key_unwrapper: GlobalRef, - get_key_method_id: JMethodID, -} - -impl CometKeyRetriever { - fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { - // Get JNI environment - let mut env = JVMClasses::get_env()?; - - Ok(CometKeyRetriever { - file_path: file_path.to_string(), - key_unwrapper, - get_key_method_id: env - .get_method_id( - "org/apache/comet/parquet/CometFileKeyUnwrapper", - "getKey", - "(Ljava/lang/String;[B)[B", - ) - .unwrap(), - }) - } -} - -impl KeyRetriever for CometKeyRetriever { - /// Get a data encryption key using the metadata stored in the Parquet file. - fn retrieve_key(&self, key_metadata: &[u8]) -> datafusion::parquet::errors::Result> { - use jni::{objects::JObject, signature::ReturnType}; - - // Get JNI environment - let mut env = JVMClasses::get_env() - .map_err(|e| datafusion::parquet::errors::ParquetError::General(e.to_string()))?; - - // Get the key unwrapper instance from GlobalRef - let unwrapper_instance = self.key_unwrapper.as_obj(); - - let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; - - // Convert file path to JString - let file_path_jstring = env.new_string(&self.file_path).unwrap(); - - // Convert key_metadata to JByteArray - let key_metadata_array = env.byte_array_from_slice(key_metadata).unwrap(); - - // Call instance method FileKeyUnwrapper.getKey(String, byte[]) -> byte[] - let result = unsafe { - env.call_method_unchecked( - instance, - self.get_key_method_id, - ReturnType::Array, - &[ - jni::objects::JValue::from(&file_path_jstring).as_jni(), - jni::objects::JValue::from(&key_metadata_array).as_jni(), - ], - ) - }; - - let result = result.unwrap(); - - // Extract the byte array from the result - let result_array = result.l().unwrap(); - - // Convert JObject to JByteArray and then to Vec - let byte_array: jni::objects::JByteArray = result_array.into(); - - let result_vec = env.convert_byte_array(&byte_array).unwrap(); - Ok(result_vec) - } -} - fn get_options( session_timezone: &str, case_sensitive: bool, From 9bc24fd51911297dabe9f46e1b92c74b0e2d648b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Mon, 29 Sep 2025 10:22:02 -0400 Subject: [PATCH 11/19] Add uniform encryption test. 
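
Uniform encryption configures a single master key for the footer and all columns via
the parquet.encryption.uniform.key write option, instead of the per-column
parquet.encryption.column.keys mapping used by the other tests. A minimal sketch of
the round trip the new test exercises (assumes the spark session, a DataFrame df, a
target directory dir, and the Base64-encoded key1 from the surrounding suite):

    withSQLConf(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME ->
        "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory",
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
        "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
      InMemoryKMS.KEY_LIST_PROPERTY_NAME -> s"key1: ${key1}") {
      df.write
        .option("parquet.encryption.uniform.key", "key1") // single key, no column mapping
        .parquet(dir)
      // Reads resolve the key through the same mock KMS configuration
      spark.read.parquet(dir).count()
    }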
--- .../sql/comet/ParquetEncryptionITCase.scala | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index b7875d3100..0120b10678 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -116,6 +116,45 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") .parquet(parquetDir) + verifyParquetEncrypted(parquetDir) + + val parquetDF = spark.read.parquet(parquetDir) + assert(parquetDF.inputFiles.nonEmpty) + val readDataset = parquetDF.select("a", "b", "c") + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(readDataset) + } else { + checkAnswer(readDataset, inputDF) + } + } + } + } + + test("SPARK-42114: Test of uniform parquet encryption") { + + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"key1: ${key1}") { + + val inputDF = spark + .range(0, 2000) + .map(i => (i, i.toString, i.toFloat)) + .repartition(10) + .toDF("a", "b", "c") + val parquetDir = new File(dir, "parquet").getCanonicalPath + inputDF.write + .option("parquet.encryption.uniform.key", "key1") + .parquet(parquetDir) + + verifyParquetEncrypted(parquetDir) + val parquetDF = spark.read.parquet(parquetDir) assert(parquetDF.inputFiles.nonEmpty) val readDataset = parquetDF.select("a", "b", "c") From bf6ad03db3ae8d9cb301d0352a244c13cdc906a4 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 30 Sep 2025 08:14:02 -0400 Subject: [PATCH 12/19] Address PR feedback. --- .../comet/parquet/NativeBatchReader.java | 3 +- .../comet/parquet/CometParquetUtils.scala | 45 +++++++------------ 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java index 1b86b6e3a5..84918d9335 100644 --- a/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java +++ b/common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java @@ -413,8 +413,9 @@ public void init() throws Throwable { boolean encryptionEnabled = CometParquetUtils.encryptionEnabled(conf); // Create keyUnwrapper if encryption is enabled - CometFileKeyUnwrapper keyUnwrapper = encryptionEnabled ? 
new CometFileKeyUnwrapper() : null; + CometFileKeyUnwrapper keyUnwrapper = null; if (encryptionEnabled) { + keyUnwrapper = new CometFileKeyUnwrapper(); keyUnwrapper.storeDecryptionKeyRetriever(file.filePath().toString(), conf); } diff --git a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala index 56d9a7dcfd..8bcf99dbd1 100644 --- a/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala +++ b/common/src/main/scala/org/apache/comet/parquet/CometParquetUtils.scala @@ -21,7 +21,7 @@ package org.apache.comet.parquet import org.apache.hadoop.conf.Configuration import org.apache.parquet.crypto.DecryptionPropertiesFactory -import org.apache.parquet.crypto.keytools.KeyToolkit +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} import org.apache.spark.sql.internal.SQLConf object CometParquetUtils { @@ -29,20 +29,15 @@ object CometParquetUtils { private val PARQUET_FIELD_ID_READ_ENABLED = "spark.sql.parquet.fieldId.read.enabled" private val IGNORE_MISSING_PARQUET_FIELD_ID = "spark.sql.parquet.fieldId.read.ignoreMissing" - // Map of unsupported encryption configuration key-value pairs - private val UNSUPPORTED_ENCRYPTION_CONFIGS: Map[String, Set[String]] = Map( - "parquet.encryption.algorithm" -> Set("AES_GCM_CTR_V1") - // Add more unsupported configs here as needed - // "parquet.encryption.some.config" -> Set("unsupported_value1", "unsupported_value2") - ) - - // Map of encryption configurations that can only have specific allowed values - private val SUPPORTED_ENCRYPTION_CONFIGS_WHITELIST: Map[String, Set[String]] = Map( - "parquet.encryption.data.key.length.bits" -> Set("128"), - "parquet.encryption.kek.length.bits" -> Set("128") - // Add more whitelisted configs here as needed - // "parquet.encryption.some.config" -> Set("allowed_value1", "allowed_value2") - ) + // Map of encryption configuration key-value pairs that, if present, are only supported with + // these specific values. Generally, these are the default values that won't be present, + // but if they are present we want to check them. 
+ private val SUPPORTED_ENCRYPTION_CONFIGS: Map[String, Set[String]] = Map( + // https://github.com/apache/arrow-rs/blob/main/parquet/src/encryption/ciphers.rs#L21 + KeyToolkit.DATA_KEY_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.DATA_KEY_LENGTH_DEFAULT.toString), + KeyToolkit.KEK_LENGTH_PROPERTY_NAME -> Set(KeyToolkit.KEK_LENGTH_DEFAULT.toString), + // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L494 + PropertiesDrivenCryptoFactory.ENCRYPTION_ALGORITHM_PROPERTY_NAME -> Set("AES_GCM_V1")) def writeFieldId(conf: SQLConf): Boolean = conf.getConfString(PARQUET_FIELD_ID_WRITE_ENABLED, "false").toBoolean @@ -66,27 +61,17 @@ object CometParquetUtils { * found */ def isEncryptionConfigSupported(hadoopConf: Configuration): Boolean = { - // Check blacklist: configurations that should never have certain values - val blacklistCheck = UNSUPPORTED_ENCRYPTION_CONFIGS.forall { - case (configKey, unsupportedValues) => - val configValue = Option(hadoopConf.get(configKey)) - configValue match { - case Some(value) => !unsupportedValues.contains(value) - case None => true // Config not set, so it's supported - } - } - - // Check whitelist: configurations that can only have specific allowed values - val whitelistCheck = SUPPORTED_ENCRYPTION_CONFIGS_WHITELIST.forall { - case (configKey, allowedValues) => + // Check configurations that, if present, can only have specific allowed values + val supportedListCheck = SUPPORTED_ENCRYPTION_CONFIGS.forall { + case (configKey, supportedValues) => val configValue = Option(hadoopConf.get(configKey)) configValue match { - case Some(value) => allowedValues.contains(value) + case Some(value) => supportedValues.contains(value) case None => true // Config not set, so it's supported } } - blacklistCheck && whitelistCheck + supportedListCheck } def encryptionEnabled(hadoopConf: Configuration): Boolean = { From 7d1bf395e8fb86240e29811767d2c83686ec2c87 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 30 Sep 2025 12:52:15 -0400 Subject: [PATCH 13/19] Add benchmark. 
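
The benchmark writes the single-column table encrypted with
PropertiesDrivenCryptoFactory and the InMemoryKMS mock, then times the same aggregate
under plain Spark and each Comet scan implementation; it also raises the numeric scan
row count so scan cost dominates. The shape of one case, reduced to its essentials
(a sketch of the full cases in the diff below, which also set off-heap memory options):

    sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ =>
      withSQLConf(
        CometConf.COMET_ENABLED.key -> "true",
        CometConf.COMET_EXEC_ENABLED.key -> "true",
        CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION,
        DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass,
        KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME ->
          "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS",
        InMemoryKMS.KEY_LIST_PROPERTY_NAME -> s"footerKey: ${footerKey}, key1: ${key1}") {
        // The same KMS properties must be set at read time so the keys can be unwrapped
        spark.sql(s"select $query from parquetV1Table").noop()
      }
    }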
--- .../sql/benchmark/CometBenchmarkBase.scala | 40 +++++++ .../sql/benchmark/CometReadBenchmark.scala | 103 +++++++++++++++++- 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala index 6e6c624910..1cbe27be91 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometBenchmarkBase.scala @@ -20,9 +20,14 @@ package org.apache.spark.sql.benchmark import java.io.File +import java.nio.charset.StandardCharsets +import java.util.Base64 import scala.util.Random +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory} +import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.SparkConf import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession} @@ -120,6 +125,41 @@ trait CometBenchmarkBase extends SqlBasedBenchmark { spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") } + protected def prepareEncryptedTable( + dir: File, + df: DataFrame, + partition: Option[String] = None): Unit = { + val testDf = if (partition.isDefined) { + df.write.partitionBy(partition.get) + } else { + df.write + } + + saveAsEncryptedParquetV1Table(testDf, dir.getCanonicalPath + "/parquetV1") + } + + protected def saveAsEncryptedParquetV1Table(df: DataFrameWriter[Row], dir: String): Unit = { + val encoder = Base64.getEncoder + val footerKey = + encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) + val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) + val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + df.mode("overwrite") + .option("compression", "snappy") + .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: id") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(dir) + spark.read.parquet(dir).createOrReplaceTempView("parquetV1Table") + } + } + protected def makeDecimalDataFrame( values: Int, decimal: DecimalType, diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala index 02b9ca5dce..a5db4f290d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometReadBenchmark.scala @@ -20,11 +20,16 @@ package org.apache.spark.sql.benchmark import java.io.File +import java.nio.charset.StandardCharsets +import java.util.Base64 import scala.jdk.CollectionConverters._ import scala.util.Random import org.apache.hadoop.fs.Path +import org.apache.parquet.crypto.DecryptionPropertiesFactory +import org.apache.parquet.crypto.keytools.KeyToolkit +import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS import org.apache.spark.TestUtils import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.{DataFrame, SparkSession} @@ -93,6 
+98,94 @@ class CometReadBaseBenchmark extends CometBenchmarkBase { } } + def encryptedScanBenchmark(values: Int, dataType: DataType): Unit = { + // Benchmarks running through spark sql. + val sqlBenchmark = + new Benchmark(s"SQL Single ${dataType.sql} Encrypted Column Scan", values, output = output) + + val encoder = Base64.getEncoder + val footerKey = + encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8)) + val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8)) + val cryptoFactoryClass = + "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory" + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareEncryptedTable( + dir, + spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM $tbl")) + + val query = dataType match { + case BooleanType => "sum(cast(id as bigint))" + case _ => "sum(id)" + } + + sqlBenchmark.addCase("SQL Parquet - Spark") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_COMET, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet Native DataFusion") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_DATAFUSION, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet - Comet Native Iceberg Compat") { _ => + withSQLConf( + "spark.memory.offHeap.enabled" -> "true", + "spark.memory.offHeap.size" -> "10g", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> SCAN_NATIVE_ICEBERG_COMPAT, + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}") { + spark.sql(s"select $query from parquetV1Table").noop() + } + } + + sqlBenchmark.run() + } + } + } + def decimalScanBenchmark(values: Int, precision: Int, scale: Int): Unit = { val sqlBenchmark = new Benchmark( s"SQL Single Decimal(precision: $precision, scale: $scale) Column 
Scan", @@ -552,13 +645,20 @@ class CometReadBaseBenchmark extends CometBenchmarkBase { } } - runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 15) { v => + runBenchmarkWithTable("SQL Single Numeric Column Scan", 1024 * 1024 * 128) { v => Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) .foreach { dataType => numericScanBenchmark(v, dataType) } } + runBenchmarkWithTable("SQL Single Numeric Encrypted Column Scan", 1024 * 1024 * 128) { v => + Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + .foreach { dataType => + encryptedScanBenchmark(v, dataType) + } + } + runBenchmark("SQL Decimal Column Scan") { withTempTable(tbl) { import spark.implicits._ @@ -639,6 +739,7 @@ object CometReadHdfsBenchmark extends CometReadBaseBenchmark with WithHdfsCluste finally getFileSystem.delete(tempHdfsPath, true) } } + override protected def prepareTable( dir: File, df: DataFrame, From 8ba2680d25ee6f03adc2cc863548798ea9e357f0 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 1 Oct 2025 10:54:48 -0400 Subject: [PATCH 14/19] Address PR feedback related to number of hadoopConfs in a Comet plan and Factory caching. --- .../comet/parquet/CometFileKeyUnwrapper.java | 15 ++++--- .../org/apache/comet/CometExecIterator.scala | 10 ++--- .../shuffle/CometNativeShuffleWriter.scala | 1 + .../apache/spark/sql/comet/operators.scala | 44 +++++++++++-------- 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java index 305cce414f..0911901d21 100644 --- a/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java +++ b/common/src/main/java/org/apache/comet/parquet/CometFileKeyUnwrapper.java @@ -96,11 +96,10 @@ public class CometFileKeyUnwrapper { private final ConcurrentHashMap retrieverCache = new ConcurrentHashMap<>(); - // Each hadoopConf yields a unique DecryptionPropertiesFactory. While it's unlikely that - // this Comet plan contains more than one hadoopConf, we don't want to assume that. So we'll - // provide the ability to cache more than one Factory with a map. - private final ConcurrentHashMap factoryCache = - new ConcurrentHashMap<>(); + // Cache the factory since we should be using the same hadoopConf for every file in this scan. + private DecryptionPropertiesFactory factory = null; + // Cache the hadoopConf just to assert the assumption above. + private Configuration conf = null; /** * Creates and stores a DecryptionKeyRetriever instance for the given file path. 
@@ -111,10 +110,12 @@ public class CometFileKeyUnwrapper { public void storeDecryptionKeyRetriever(final String filePath, final Configuration hadoopConf) { // Use DecryptionPropertiesFactory.loadFactory to get the factory and then call // getFileDecryptionProperties - DecryptionPropertiesFactory factory = factoryCache.get(hadoopConf); if (factory == null) { factory = DecryptionPropertiesFactory.loadFactory(hadoopConf); - factoryCache.put(hadoopConf, factory); + conf = hadoopConf; + } else { + // Check the assumption that all files have the same hadoopConf and thus same Factory + assert (conf == hadoopConf); } Path path = new Path(filePath); FileDecryptionProperties decryptionProperties = diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index e1c3e65623..0e0e1117f1 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( nativeMetrics: CometMetricNode, numParts: Int, partitionIndex: Int, - encryptedFilePaths: Seq[(String, Broadcast[SerializableConfiguration])] = Seq.empty) + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, + encryptedFilePaths: Seq[String] = Seq.empty) extends Iterator[ColumnarBatch] with Logging { @@ -112,11 +113,10 @@ class CometExecIterator( // Create keyUnwrapper if encryption is enabled val keyUnwrapper = if (encryptedFilePaths.nonEmpty) { val unwrapper = new CometFileKeyUnwrapper() + val hadoopConf: Configuration = broadcastedHadoopConfForEncryption.get.value.value - encryptedFilePaths.foreach { case (filePath, broadcastedConf) => - val hadoopConf: Configuration = broadcastedConf.value.value - unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf) - } + encryptedFilePaths.foreach(filePath => + unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf)) unwrapper } else { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 9d915cbc6c..43a1e5b9a0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -104,6 +104,7 @@ class CometNativeShuffleWriter[K, V]( nativeMetrics, numParts, context.partitionId(), + broadcastedHadoopConfForEncryption = None, encryptedFilePaths = Seq.empty) while (cometIter.hasNext) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index b4709cecdf..632f7e04bf 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -117,6 +117,7 @@ object CometExec { CometMetricNode(Map.empty), numParts, partitionIdx, + broadcastedHadoopConfForEncryption = None, encryptedFilePaths = Seq.empty) } @@ -127,8 +128,8 @@ object CometExec { nativeMetrics: CometMetricNode, numParts: Int, partitionIdx: Int, - encryptedFilePaths: Seq[(String, Broadcast[SerializableConfiguration])]) - : CometExecIterator = { + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]], + encryptedFilePaths: Seq[String]): CometExecIterator = { val outputStream = new ByteArrayOutputStream() 
nativePlan.writeTo(outputStream) outputStream.close() @@ -141,6 +142,7 @@ object CometExec { nativeMetrics, numParts, partitionIdx, + broadcastedHadoopConfForEncryption, encryptedFilePaths) } @@ -213,23 +215,28 @@ abstract class CometNativeExec extends CometExec { .collectLeaves() .filter(_.isInstanceOf[CometNativeScanExec]) .map(_.asInstanceOf[CometNativeScanExec]) - val encryptedFilePaths = cometNativeScans.flatMap { scan => - // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and - // per-relation configs since different tables might have different decryption - // properties. - val hadoopConf = scan.relation.sparkSession.sessionState - .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) - if (encryptionEnabled) { - // hadoopConf isn't serializable, so we have to do a broadcasted config. - val broadcastedConf = - scan.relation.sparkSession.sparkContext - .broadcast(new SerializableConfiguration(hadoopConf)) - scan.relation.inputFiles.map { filePath => (filePath, broadcastedConf) } - } else { - Seq.empty + assert( + cometNativeScans.size <= 1, + "We expect one native scan in a Comet plan since we will broadcast one hadoopConf.") + val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = + cometNativeScans.headOption.fold( + (None: Option[Broadcast[SerializableConfiguration]], Seq.empty[String])) { scan => + // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and + // per-relation configs since different tables might have different decryption + // properties. + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) + if (encryptionEnabled) { + // hadoopConf isn't serializable, so we have to do a broadcasted config. + val broadcastedConf = + scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + (Some(broadcastedConf), scan.relation.inputFiles.toSeq) + } else { + (None, Seq.empty) + } } - } def createCometExecIter( inputs: Seq[Iterator[ColumnarBatch]], @@ -243,6 +250,7 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, numParts, partitionIndex, + broadcastedHadoopConfForEncryption, encryptedFilePaths) setSubqueries(it.id, this) From e9fcca72c2894c5fd8b823f5aa2649d8dbc30450 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 1 Oct 2025 11:17:34 -0400 Subject: [PATCH 15/19] Adjust error handling. --- native/core/src/errors.rs | 9 ++++ native/core/src/parquet/encryption_support.rs | 47 ++++++++++++++----- native/core/src/parquet/parquet_exec.rs | 2 +- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/native/core/src/errors.rs b/native/core/src/errors.rs index b3241477b8..ecac7af94e 100644 --- a/native/core/src/errors.rs +++ b/native/core/src/errors.rs @@ -185,6 +185,15 @@ impl From for DataFusionError { } } +impl From for ParquetError { + fn from(value: CometError) -> Self { + match value { + CometError::Parquet { source } => source, + _ => ParquetError::General(value.to_string()), + } + } +} + impl From for ExecutionError { fn from(value: CometError) -> Self { match value { diff --git a/native/core/src/parquet/encryption_support.rs b/native/core/src/parquet/encryption_support.rs index 1297d67ef1..ff67a3fcbd 100644 --- a/native/core/src/parquet/encryption_support.rs +++ b/native/core/src/parquet/encryption_support.rs @@ -16,7 +16,7 @@ // under the License. 
use crate::execution::operators::ExecutionError; -use crate::jvm_bridge::JVMClasses; +use crate::jvm_bridge::{check_exception, JVMClasses}; use arrow::datatypes::SchemaRef; use async_trait::async_trait; use datafusion::common::extensions_options; @@ -27,14 +27,16 @@ use jni::objects::{GlobalRef, JMethodID}; use object_store::path::Path; use parquet::encryption::decrypt::{FileDecryptionProperties, KeyRetriever}; use parquet::encryption::encrypt::FileEncryptionProperties; +use parquet::errors::ParquetError; use std::sync::Arc; pub const ENCRYPTION_FACTORY_ID: &str = "comet.jni_kms_encryption"; -// Options used to configure our example encryption factory extensions_options! { pub struct CometEncryptionConfig { - pub url_base: String, default = "file:///".into() + // Native side strips file down to a path (not a URI) but Spark wants the full URI, + // so we cache the prefix to stick on the front before calling over JNI + pub uri_base: String, default = "file:///".into() } } @@ -70,7 +72,7 @@ impl EncryptionFactory for CometEncryptionFactory { ) -> Result, DataFusionError> { let config: CometEncryptionConfig = options.to_extension_options()?; - let full_path: String = config.url_base + file_path.as_ref(); + let full_path: String = config.uri_base + file_path.as_ref(); let key_retriever = CometKeyRetriever::new(&full_path, self.key_unwrapper.clone()) .map_err(|e| DataFusionError::External(Box::new(e)))?; let decryption_properties = @@ -87,7 +89,6 @@ pub struct CometKeyRetriever { impl CometKeyRetriever { pub fn new(file_path: &str, key_unwrapper: GlobalRef) -> Result { - // Get JNI environment let mut env = JVMClasses::get_env()?; Ok(CometKeyRetriever { @@ -99,7 +100,9 @@ impl CometKeyRetriever { "getKey", "(Ljava/lang/String;[B)[B", ) - .unwrap(), + .map_err(|e| { + ExecutionError::GeneralError(format!("Failed to get JNI method ID: {}", e)) + })?, }) } } @@ -110,8 +113,7 @@ impl KeyRetriever for CometKeyRetriever { use jni::{objects::JObject, signature::ReturnType}; // Get JNI environment - let mut env = JVMClasses::get_env() - .map_err(|e| datafusion::parquet::errors::ParquetError::General(e.to_string()))?; + let mut env = JVMClasses::get_env()?; // Get the key unwrapper instance from GlobalRef let unwrapper_instance = self.key_unwrapper.as_obj(); @@ -119,10 +121,14 @@ impl KeyRetriever for CometKeyRetriever { let instance: JObject = unsafe { JObject::from_raw(unwrapper_instance.as_raw()) }; // Convert file path to JString - let file_path_jstring = env.new_string(&self.file_path).unwrap(); + let file_path_jstring = env + .new_string(&self.file_path) + .map_err(|e| ParquetError::General(format!("Failed to create JString: {}", e)))?; // Convert key_metadata to JByteArray - let key_metadata_array = env.byte_array_from_slice(key_metadata).unwrap(); + let key_metadata_array = env + .byte_array_from_slice(key_metadata) + .map_err(|e| ParquetError::General(format!("Failed to create byte array: {}", e)))?; // Call instance method FileKeyUnwrapper.getKey(String, byte[]) -> byte[] let result = unsafe { @@ -137,15 +143,30 @@ impl KeyRetriever for CometKeyRetriever { ) }; - let result = result.unwrap(); + // Check for Java exceptions first, before processing the result + if let Some(exception) = check_exception(&mut env).map_err(|e| { + ParquetError::General(format!("Failed to check for Java exception: {}", e)) + })? 
{ + return Err(ParquetError::General(format!( + "Java exception during key retrieval: {}", + exception + ))); + } + + let result = + result.map_err(|e| ParquetError::General(format!("JNI method call failed: {}", e)))?; // Extract the byte array from the result - let result_array = result.l().unwrap(); + let result_array = result + .l() + .map_err(|e| ParquetError::General(format!("Failed to extract result: {}", e)))?; // Convert JObject to JByteArray and then to Vec let byte_array: jni::objects::JByteArray = result_array.into(); - let result_vec = env.convert_byte_array(&byte_array).unwrap(); + let result_vec = env + .convert_byte_array(&byte_array) + .map_err(|e| ParquetError::General(format!("Failed to convert byte array: {}", e)))?; Ok(result_vec) } } diff --git a/native/core/src/parquet/parquet_exec.rs b/native/core/src/parquet/parquet_exec.rs index dd73bcee13..0a95ec9996 100644 --- a/native/core/src/parquet/parquet_exec.rs +++ b/native/core/src/parquet/parquet_exec.rs @@ -158,7 +158,7 @@ fn get_options( table_parquet_options.crypto.configure_factory( ENCRYPTION_FACTORY_ID, &CometEncryptionConfig { - url_base: object_store_url.to_string(), + uri_base: object_store_url.to_string(), }, ); } From e8c23ab22cf1cf50b66001e27ab947a72ee9b3f5 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Mon, 6 Oct 2025 09:00:10 -0400 Subject: [PATCH 16/19] Add test with UNION. --- .../sql/comet/ParquetEncryptionITCase.scala | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 0120b10678..9532df41d3 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -323,6 +323,9 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") .parquet(parquetDir2) + verifyParquetEncrypted(parquetDir1) + verifyParquetEncrypted(parquetDir2) + // Now perform a join between the two files with different encryption keys // This tests that hadoopConf properties propagate correctly to each scan val parquetDF1 = spark.read.parquet(parquetDir1).alias("f1") @@ -345,6 +348,62 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { } } + // Union ends up with two scans in the same plan, so this ensures that Comet can distinguish + // between the hadoopConfs for each relation + test("Union between files with different encryption keys") { + import testImplicits._ + + withTempDir { dir => + withSQLConf( + DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME -> cryptoFactoryClass, + KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME -> + "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS", + InMemoryKMS.KEY_LIST_PROPERTY_NAME -> + s"footerKey: ${footerKey}, key1: ${key1}, key2: ${key2}") { + + // Write first file with key1 + val inputDF1 = spark + .range(0, 100) + .map(i => (i, s"file1_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir1 = new File(dir, "parquet1").getCanonicalPath + inputDF1.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key1: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir1) + + // Write second file with key2 - same schema, different encryption key + val inputDF2 = spark + .range(100, 200) + 
.map(i => (i, s"file2_${i}", i.toFloat)) + .toDF("id", "name", "value") + val parquetDir2 = new File(dir, "parquet2").getCanonicalPath + inputDF2.write + .option( + PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, + "key2: id, name, value") + .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey") + .parquet(parquetDir2) + + verifyParquetEncrypted(parquetDir1) + verifyParquetEncrypted(parquetDir2) + + val parquetDF1 = spark.read.parquet(parquetDir1) + val parquetDF2 = spark.read.parquet(parquetDir2) + + val unionDF = parquetDF1.union(parquetDF2) + + if (CometConf.COMET_ENABLED.get(conf)) { + checkSparkAnswerAndOperator(unionDF) + } else { + checkSparkAnswer(unionDF) + } + } + } + } + test("Test different key lengths") { import testImplicits._ From a53aa2fbe8a3ed7820ea556a1caf935b94a57143 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Mon, 6 Oct 2025 14:17:18 -0400 Subject: [PATCH 17/19] Add docs to reflect UNION discussion in PR feedback. --- .../src/main/scala/org/apache/spark/sql/comet/operators.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 632f7e04bf..de6892638a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -218,6 +218,10 @@ abstract class CometNativeExec extends CometExec { assert( cometNativeScans.size <= 1, "We expect one native scan in a Comet plan since we will broadcast one hadoopConf.") + // If this assumption changes in the future, you can look at the commit history of #2447 + // to see how there used to be a map of relations to broadcasted confs in case multiple + // relations in a single plan. The example that came up was UNION. See discussion at: + // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = cometNativeScans.headOption.fold( (None: Option[Broadcast[SerializableConfiguration]], Seq.empty[String])) { scan => From fbe2c96bce394f34b53b6d0f0887e8414ceceac0 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Mon, 6 Oct 2025 16:44:42 -0400 Subject: [PATCH 18/19] Address PR feedback. 
--- .../sql/comet/ParquetEncryptionITCase.scala | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala index 9532df41d3..cff21ecec3 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetEncryptionITCase.scala @@ -237,7 +237,7 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { val readDataset = parquetDF.select("a", "b", "c") // native_datafusion and native_iceberg_compat fall back due to Arrow-rs - // https://github.com/apache/arrow-rs/blob/main/parquet/src/file/metadata/parser.rs#L414 + // https://github.com/apache/arrow-rs/blob/da9829728e2a9dffb8d4f47ffe7b103793851724/parquet/src/file/metadata/parser.rs#L494 if (CometConf.COMET_ENABLED.get(conf) && CometConf.COMET_NATIVE_SCAN_IMPL.get( conf) == SCAN_NATIVE_COMET) { checkSparkAnswerAndOperator(readDataset) @@ -512,12 +512,19 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { val byteArray = new Array[Byte](magicStringLength) val randomAccessFile = new RandomAccessFile(parquetFile, "r") try { + // Check first 4 bytes randomAccessFile.read(byteArray, 0, magicStringLength) + val firstMagicString = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == firstMagicString) + + // Check last 4 bytes + randomAccessFile.seek(randomAccessFile.length() - magicStringLength) + randomAccessFile.read(byteArray, 0, magicStringLength) + val lastMagicString = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == lastMagicString) } finally { randomAccessFile.close() } - val stringRead = new String(byteArray, StandardCharsets.UTF_8) - assert(magicString == stringRead) } } @@ -536,12 +543,19 @@ class ParquetEncryptionITCase extends CometTestBase with SQLTestUtils { val byteArray = new Array[Byte](magicStringLength) val randomAccessFile = new RandomAccessFile(parquetFile, "r") try { + // Check first 4 bytes + randomAccessFile.read(byteArray, 0, magicStringLength) + val firstMagicString = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == firstMagicString) + + // Check last 4 bytes + randomAccessFile.seek(randomAccessFile.length() - magicStringLength) randomAccessFile.read(byteArray, 0, magicStringLength) + val lastMagicString = new String(byteArray, StandardCharsets.UTF_8) + assert(magicString == lastMagicString) } finally { randomAccessFile.close() } - val stringRead = new String(byteArray, StandardCharsets.UTF_8) - assert(magicString == stringRead) } } From 257f163b23a8777a43000db0f2d1637c3e962c59 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Mon, 6 Oct 2025 17:19:45 -0400 Subject: [PATCH 19/19] Fix after merge conflicts with main. 
--- spark/src/main/scala/org/apache/comet/CometExecIterator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 379e91161f..8603a7b9a8 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -143,7 +143,7 @@ class CometExecIterator( debug = COMET_DEBUG_ENABLED.get(), explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), tracingEnabled, - maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get()), + maxTempDirectorySize = CometConf.COMET_MAX_TEMP_DIRECTORY_SIZE.get(), keyUnwrapper) }
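
For an end-to-end check of the new code path, the integration tests and the added benchmark configure parquet-mr's PropertiesDrivenCryptoFactory together with the mock InMemoryKMS. The sketch below condenses that setup into a spark-shell style snippet. The property names, factory, and KMS classes are the ones used in the tests and benchmark above; the key material, output path, and the choice of native scan implementation are illustrative, and it assumes the parquet-hadoop test artifact that provides InMemoryKMS is on the classpath:

    import java.nio.charset.StandardCharsets
    import java.util.Base64

    import org.apache.parquet.crypto.DecryptionPropertiesFactory
    import org.apache.parquet.crypto.keytools.{KeyToolkit, PropertiesDrivenCryptoFactory}
    import org.apache.parquet.crypto.keytools.mocks.InMemoryKMS

    import org.apache.comet.CometConf

    val encoder = Base64.getEncoder
    val footerKey = encoder.encodeToString("0123456789012345".getBytes(StandardCharsets.UTF_8))
    val key1 = encoder.encodeToString("1234567890123450".getBytes(StandardCharsets.UTF_8))

    // Wire up the crypto factory and the mock KMS. These session conf entries end up in the
    // per-relation Hadoop conf that Comet broadcasts when encryption is enabled, so the native
    // reader can unwrap keys through CometFileKeyUnwrapper over JNI.
    spark.conf.set(
      DecryptionPropertiesFactory.CRYPTO_FACTORY_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.PropertiesDrivenCryptoFactory")
    spark.conf.set(
      KeyToolkit.KMS_CLIENT_CLASS_PROPERTY_NAME,
      "org.apache.parquet.crypto.keytools.mocks.InMemoryKMS")
    spark.conf.set(InMemoryKMS.KEY_LIST_PROPERTY_NAME, s"footerKey: $footerKey, key1: $key1")

    // Enable Comet with one of the native scan implementations exercised by the benchmark.
    spark.conf.set(CometConf.COMET_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
    spark.conf.set(CometConf.COMET_NATIVE_SCAN_IMPL.key, CometConf.SCAN_NATIVE_DATAFUSION)

    // Write a column-encrypted file and read it back through the native scan.
    val dir = "/tmp/comet_encrypted_ids"
    spark.range(0, 1000).toDF("id")
      .write
      .mode("overwrite")
      .option(PropertiesDrivenCryptoFactory.COLUMN_KEYS_PROPERTY_NAME, "key1: id")
      .option(PropertiesDrivenCryptoFactory.FOOTER_KEY_PROPERTY_NAME, "footerKey")
      .parquet(dir)
    spark.read.parquet(dir).selectExpr("sum(id)").show()

As in the tests above, switching the scan implementation between SCAN_NATIVE_COMET, SCAN_NATIVE_DATAFUSION, and SCAN_NATIVE_ICEBERG_COMPAT exercises the different readers against the same encrypted data.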