diff --git a/.github/workflows/iceberg_spark_test.yml b/.github/workflows/iceberg_spark_test.yml index 74badcda5f..54ef48d786 100644 --- a/.github/workflows/iceberg_spark_test.yml +++ b/.github/workflows/iceberg_spark_test.yml @@ -164,7 +164,7 @@ jobs: -Pquick=true -x javadoc iceberg-spark-rust: - if: contains(github.event.pull_request.title, '[iceberg]') + if: contains(github.event.pull_request.title, '[iceberg-rust]') strategy: matrix: os: [ubuntu-24.04] @@ -203,7 +203,7 @@ jobs: -Pquick=true -x javadoc iceberg-spark-extensions-rust: - if: contains(github.event.pull_request.title, '[iceberg]') + if: contains(github.event.pull_request.title, '[iceberg-rust]') strategy: matrix: os: [ubuntu-24.04] @@ -242,7 +242,7 @@ jobs: -Pquick=true -x javadoc iceberg-spark-runtime-rust: - if: contains(github.event.pull_request.title, '[iceberg]') + if: contains(github.event.pull_request.title, '[iceberg-rust]') strategy: matrix: os: [ubuntu-24.04] diff --git a/native/core/src/execution/operators/iceberg_scan.rs b/native/core/src/execution/operators/iceberg_scan.rs index 2f639e9f70..e23aa49e07 100644 --- a/native/core/src/execution/operators/iceberg_scan.rs +++ b/native/core/src/execution/operators/iceberg_scan.rs @@ -130,8 +130,18 @@ impl ExecutionPlan for IcebergScanExec { partition: usize, context: Arc, ) -> DFResult { - if partition < self.file_task_groups.len() { - let tasks = &self.file_task_groups[partition]; + // In split mode (single task group), always use index 0 regardless of requested partition. + // This is because in Comet's per-partition execution model, each task builds its own plan + // with only its partition's data. The parent operator may request partition N, but this + // IcebergScanExec already contains the correct data for partition N in task_groups[0]. + let effective_partition = if self.file_task_groups.len() == 1 { + 0 + } else { + partition + }; + + if effective_partition < self.file_task_groups.len() { + let tasks = &self.file_task_groups[effective_partition]; self.execute_with_tasks(tasks.clone(), partition, context) } else { Err(DataFusionError::Execution(format!( diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 44ff20a44f..57d3fc1a49 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1132,26 +1132,27 @@ impl PhysicalPlanner { )) } OpStruct::IcebergScan(scan) => { - let required_schema: SchemaRef = - convert_spark_types_to_arrow_schema(scan.required_schema.as_slice()); + // Extract common data and single partition's file tasks + // Per-partition injection happens in Scala before sending to native + let common = scan + .common + .as_ref() + .ok_or_else(|| GeneralError("IcebergScan missing common data".into()))?; + let partition = scan + .partition + .as_ref() + .ok_or_else(|| GeneralError("IcebergScan missing partition data".into()))?; - let catalog_properties: HashMap = scan + let required_schema = + convert_spark_types_to_arrow_schema(common.required_schema.as_slice()); + let catalog_properties: HashMap = common .catalog_properties .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(); + let metadata_location = common.metadata_location.clone(); + let tasks = parse_file_scan_tasks_from_common(common, &partition.file_scan_tasks)?; - let metadata_location = scan.metadata_location.clone(); - - debug_assert!( - !scan.file_partitions.is_empty(), - "IcebergScan must have at least one file partition. This indicates a bug in Scala serialization." 
- ); - - let tasks = parse_file_scan_tasks( - scan, - &scan.file_partitions[self.partition as usize].file_scan_tasks, - )?; let file_task_groups = vec![tasks]; let iceberg_scan = IcebergScanExec::new( @@ -2743,15 +2744,14 @@ fn partition_data_to_struct( /// Each task contains a residual predicate that is used for row-group level filtering /// during Parquet scanning. /// -/// This function uses deduplication pools from the IcebergScan to avoid redundant parsing -/// of schemas, partition specs, partition types, name mappings, and other repeated data. -fn parse_file_scan_tasks( - proto_scan: &spark_operator::IcebergScan, +/// This function uses deduplication pools from the IcebergScanCommon to avoid redundant +/// parsing of schemas, partition specs, partition types, name mappings, and other repeated data. +fn parse_file_scan_tasks_from_common( + proto_common: &spark_operator::IcebergScanCommon, proto_tasks: &[spark_operator::IcebergFileScanTask], ) -> Result, ExecutionError> { - // Build caches upfront: for 10K tasks with 1 schema, this parses the schema - // once instead of 10K times, eliminating redundant JSON deserialization - let schema_cache: Vec> = proto_scan + // Build caches upfront from common data + let schema_cache: Vec> = proto_common .schema_pool .iter() .map(|json| { @@ -2764,7 +2764,7 @@ fn parse_file_scan_tasks( }) .collect::, _>>()?; - let partition_spec_cache: Vec>> = proto_scan + let partition_spec_cache: Vec>> = proto_common .partition_spec_pool .iter() .map(|json| { @@ -2774,7 +2774,7 @@ fn parse_file_scan_tasks( }) .collect(); - let name_mapping_cache: Vec>> = proto_scan + let name_mapping_cache: Vec>> = proto_common .name_mapping_pool .iter() .map(|json| { @@ -2784,7 +2784,7 @@ fn parse_file_scan_tasks( }) .collect(); - let delete_files_cache: Vec> = proto_scan + let delete_files_cache: Vec> = proto_common .delete_files_pool .iter() .map(|list| { @@ -2796,7 +2796,7 @@ fn parse_file_scan_tasks( "EQUALITY_DELETES" => iceberg::spec::DataContentType::EqualityDeletes, other => { return Err(GeneralError(format!( - "Invalid delete content type '{}'. 
This indicates a bug in Scala serialization.", + "Invalid delete content type '{}'", other ))) } @@ -2817,7 +2817,6 @@ fn parse_file_scan_tasks( }) .collect::, _>>()?; - // Partition data pool is in protobuf messages let results: Result, _> = proto_tasks .iter() .map(|proto_task| { @@ -2851,7 +2850,7 @@ fn parse_file_scan_tasks( }; let bound_predicate = if let Some(idx) = proto_task.residual_idx { - proto_scan + proto_common .residual_pool .get(idx as usize) .and_then(convert_spark_expr_to_predicate) @@ -2871,24 +2870,22 @@ fn parse_file_scan_tasks( }; let partition = if let Some(partition_data_idx) = proto_task.partition_data_idx { - // Get partition data from protobuf pool - let partition_data_proto = proto_scan + let partition_data_proto = proto_common .partition_data_pool .get(partition_data_idx as usize) .ok_or_else(|| { ExecutionError::GeneralError(format!( "Invalid partition_data_idx: {} (pool size: {})", partition_data_idx, - proto_scan.partition_data_pool.len() + proto_common.partition_data_pool.len() )) })?; - // Convert protobuf PartitionData to iceberg Struct match partition_data_to_struct(partition_data_proto) { Ok(s) => Some(s), Err(e) => { return Err(ExecutionError::GeneralError(format!( - "Failed to deserialize partition data from protobuf: {}", + "Failed to deserialize partition data: {}", e ))) } @@ -2907,14 +2904,14 @@ fn parse_file_scan_tasks( .and_then(|idx| name_mapping_cache.get(idx as usize)) .and_then(|opt| opt.clone()); - let project_field_ids = proto_scan + let project_field_ids = proto_common .project_field_ids_pool .get(proto_task.project_field_ids_idx as usize) .ok_or_else(|| { ExecutionError::GeneralError(format!( "Invalid project_field_ids_idx: {} (pool size: {})", proto_task.project_field_ids_idx, - proto_scan.project_field_ids_pool.len() + proto_common.project_field_ids_pool.len() )) })? .field_ids diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index 6dfe546ac8..a55657b7af 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -34,6 +34,7 @@ pub mod spark_partitioning { // Include generated modules from .proto files. #[allow(missing_docs)] +#[allow(clippy::large_enum_variant)] pub mod spark_operator { include!(concat!("generated", "/spark.spark_operator.rs")); } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 73c087cf36..3e89628d46 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -156,28 +156,34 @@ message PartitionData { repeated PartitionValue values = 1; } -message IcebergScan { +// Common data shared by all partitions in split mode (sent once, captured in closure) +message IcebergScanCommon { + // Catalog-specific configuration for FileIO (credentials, S3/GCS config, etc.) + map catalog_properties = 1; + + // Table metadata file path for FileIO initialization + string metadata_location = 2; + // Schema to read - repeated SparkStructField required_schema = 1; + repeated SparkStructField required_schema = 3; - // Catalog-specific configuration for FileIO (credentials, S3/GCS config, etc.) 
- map catalog_properties = 2; + // Deduplication pools (must contain ALL entries for cross-partition deduplication) + repeated string schema_pool = 4; + repeated string partition_type_pool = 5; + repeated string partition_spec_pool = 6; + repeated string name_mapping_pool = 7; + repeated ProjectFieldIdList project_field_ids_pool = 8; + repeated PartitionData partition_data_pool = 9; + repeated DeleteFileList delete_files_pool = 10; + repeated spark.spark_expression.Expr residual_pool = 11; +} - // Pre-planned file scan tasks grouped by Spark partition - repeated IcebergFilePartition file_partitions = 3; +message IcebergScan { + // Common data shared across partitions (pools, metadata, catalog props) + IcebergScanCommon common = 1; - // Table metadata file path for FileIO initialization - string metadata_location = 4; - - // Deduplication pools - shared data referenced by index from tasks - repeated string schema_pool = 5; - repeated string partition_type_pool = 6; - repeated string partition_spec_pool = 7; - repeated string name_mapping_pool = 8; - repeated ProjectFieldIdList project_field_ids_pool = 9; - repeated PartitionData partition_data_pool = 10; - repeated DeleteFileList delete_files_pool = 11; - repeated spark.spark_expression.Expr residual_pool = 12; + // Single partition's file scan tasks + IcebergFilePartition partition = 2; } // Helper message for deduplicating field ID lists diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala index 7238f8ae8c..84297a244a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala @@ -41,6 +41,12 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit override def enabledConfig: Option[ConfigEntry[Boolean]] = None + /** Thread-local storage for split serialization data. */ + private case class SplitData(commonBytes: Array[Byte], perPartitionBytes: Array[Array[Byte]]) + private val splitDataThreadLocal = new java.lang.ThreadLocal[Option[SplitData]] { + override def initialValue(): Option[SplitData] = None + } + /** * Constants specific to Iceberg expression conversion (not in shared IcebergReflection). 
*/ @@ -309,7 +315,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit contentScanTaskClass: Class[_], fileScanTaskClass: Class[_], taskBuilder: OperatorOuterClass.IcebergFileScanTask.Builder, - icebergScanBuilder: OperatorOuterClass.IcebergScan.Builder, + commonBuilder: OperatorOuterClass.IcebergScanCommon.Builder, partitionTypeToPoolIndex: mutable.HashMap[String, Int], partitionSpecToPoolIndex: mutable.HashMap[String, Int], partitionDataToPoolIndex: mutable.HashMap[String, Int]): Unit = { @@ -334,7 +340,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val specIdx = partitionSpecToPoolIndex.getOrElseUpdate( partitionSpecJson, { val idx = partitionSpecToPoolIndex.size - icebergScanBuilder.addPartitionSpecPool(partitionSpecJson) + commonBuilder.addPartitionSpecPool(partitionSpecJson) idx }) taskBuilder.setPartitionSpecIdx(specIdx) @@ -415,7 +421,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val typeIdx = partitionTypeToPoolIndex.getOrElseUpdate( partitionTypeJson, { val idx = partitionTypeToPoolIndex.size - icebergScanBuilder.addPartitionTypePool(partitionTypeJson) + commonBuilder.addPartitionTypePool(partitionTypeJson) idx }) taskBuilder.setPartitionTypeIdx(typeIdx) @@ -470,7 +476,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val partitionDataIdx = partitionDataToPoolIndex.getOrElseUpdate( partitionDataKey, { val idx = partitionDataToPoolIndex.size - icebergScanBuilder.addPartitionDataPool(partitionDataProto) + commonBuilder.addPartitionDataPool(partitionDataProto) idx }) taskBuilder.setPartitionDataIdx(partitionDataIdx) @@ -682,6 +688,8 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit builder: Operator.Builder, childOp: Operator*): Option[OperatorOuterClass.Operator] = { val icebergScanBuilder = OperatorOuterClass.IcebergScan.newBuilder() + // commonBuilder holds shared data (pools, metadata) - built throughout this method + val commonBuilder = OperatorOuterClass.IcebergScanCommon.newBuilder() // Deduplication structures - map unique values to pool indices val schemaToPoolIndex = mutable.HashMap[AnyRef, Int]() @@ -694,6 +702,9 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int]() val residualToPoolIndex = mutable.HashMap[Option[Expr], Int]() + // Per-partition file tasks (for split serialization - injected at execution time) + val perPartitionBuilders = mutable.ArrayBuffer[OperatorOuterClass.IcebergFilePartition]() + var totalTasks = 0 // Get pre-extracted metadata from planning phase @@ -707,10 +718,10 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } // Use pre-extracted metadata (no reflection needed) - icebergScanBuilder.setMetadataLocation(metadata.metadataLocation) + commonBuilder.setMetadataLocation(metadata.metadataLocation) metadata.catalogProperties.foreach { case (key, value) => - icebergScanBuilder.putCatalogProperties(key, value) + commonBuilder.putCatalogProperties(key, value) } // Set required_schema from output @@ -720,7 +731,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit .setName(attr.name) .setNullable(attr.nullable) serializeDataType(attr.dataType).foreach(field.setDataType) - icebergScanBuilder.addRequiredSchema(field.build()) + commonBuilder.addRequiredSchema(field.build()) } // Extract FileScanTasks from the 
InputPartitions in the RDD @@ -857,7 +868,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit schema, { val idx = schemaToPoolIndex.size val schemaJson = toJsonMethod.invoke(null, schema).asInstanceOf[String] - icebergScanBuilder.addSchemaPool(schemaJson) + commonBuilder.addSchemaPool(schemaJson) idx }) taskBuilder.setSchemaIdx(schemaIdx) @@ -886,7 +897,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val idx = projectFieldIdsToPoolIndex.size val listBuilder = OperatorOuterClass.ProjectFieldIdList.newBuilder() projectFieldIds.foreach(id => listBuilder.addFieldIds(id)) - icebergScanBuilder.addProjectFieldIdsPool(listBuilder.build()) + commonBuilder.addProjectFieldIdsPool(listBuilder.build()) idx }) taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx) @@ -909,7 +920,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val idx = deleteFilesToPoolIndex.size val listBuilder = OperatorOuterClass.DeleteFileList.newBuilder() deleteFilesList.foreach(df => listBuilder.addDeleteFiles(df)) - icebergScanBuilder.addDeleteFilesPool(listBuilder.build()) + commonBuilder.addDeleteFilesPool(listBuilder.build()) idx }) taskBuilder.setDeleteFilesIdx(deleteFilesIdx) @@ -938,7 +949,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val residualIdx = residualToPoolIndex.getOrElseUpdate( Some(residualExpr), { val idx = residualToPoolIndex.size - icebergScanBuilder.addResidualPool(residualExpr) + commonBuilder.addResidualPool(residualExpr) idx }) taskBuilder.setResidualIdx(residualIdx) @@ -950,7 +961,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit contentScanTaskClass, fileScanTaskClass, taskBuilder, - icebergScanBuilder, + commonBuilder, partitionTypeToPoolIndex, partitionSpecToPoolIndex, partitionDataToPoolIndex) @@ -960,7 +971,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val nmIdx = nameMappingToPoolIndex.getOrElseUpdate( nm, { val idx = nameMappingToPoolIndex.size - icebergScanBuilder.addNameMappingPool(nm) + commonBuilder.addNameMappingPool(nm) idx }) taskBuilder.setNameMappingIdx(nmIdx) @@ -972,8 +983,10 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } } + // Collect partition data for later per-partition injection + // Do NOT add to file_partitions (legacy format) val builtPartition = partitionBuilder.build() - icebergScanBuilder.addFilePartitions(builtPartition) + perPartitionBuilders += builtPartition } case _ => } @@ -1011,7 +1024,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } // Calculate partition data pool size in bytes (protobuf format) - val partitionDataPoolBytes = icebergScanBuilder.getPartitionDataPoolList.asScala + val partitionDataPoolBytes = commonBuilder.getPartitionDataPoolList.asScala .map(_.getSerializedSize) .sum @@ -1022,6 +1035,16 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit s"$partitionDataPoolBytes bytes (protobuf)") } + // Embed common data into IcebergScan (partition is injected at execution time) + icebergScanBuilder.setCommon(commonBuilder.build()) + // Note: partition is NOT set here - it gets injected per-partition at execution time + + // Store per-partition data for injection at execution time + val commonBytes = commonBuilder.build().toByteArray + val perPartitionBytes = perPartitionBuilders.map(_.toByteArray).toArray + + 
splitDataThreadLocal.set(Some(SplitData(commonBytes, perPartitionBytes))) + builder.clearChildren() Some(builder.setIcebergScan(icebergScanBuilder).build()) } @@ -1036,10 +1059,24 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit "Metadata should have been extracted in CometScanRule.") } - // Extract metadataLocation from the native operator - val metadataLocation = nativeOp.getIcebergScan.getMetadataLocation + // Extract metadataLocation from the native operator's common data + val metadataLocation = nativeOp.getIcebergScan.getCommon.getMetadataLocation - // Create the CometIcebergNativeScanExec using the companion object's apply method - CometIcebergNativeScanExec(nativeOp, op.wrapped, op.session, metadataLocation, metadata) + // Retrieve split data from thread-local (set during convert()) + val splitData = splitDataThreadLocal.get().getOrElse { + throw new IllegalStateException( + "Programming error: split data not available. " + + "buildAndStoreSplitData() should have been called during convert().") + } + splitDataThreadLocal.remove() + + CometIcebergNativeScanExec( + nativeOp, + op.wrapped, + op.session, + metadataLocation, + metadata, + splitData.commonBytes, + splitData.perPartitionBytes) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala index 223ae4fbb7..5a3a6b8494 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala @@ -21,12 +21,14 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.AccumulatorV2 import com.google.common.base.Objects @@ -49,7 +51,11 @@ case class CometIcebergNativeScanExec( override val serializedPlanOpt: SerializedPlan, metadataLocation: String, numPartitions: Int, - @transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata) + @transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata, + // Split mode: serialized IcebergScanCommon (captured in closure, sent with task) + commonData: Array[Byte] = Array.empty, + // Split mode: serialized IcebergFilePartition per partition (transient) + @transient perPartitionData: Array[Array[Byte]] = Array.empty) extends CometLeafExec { override val supportsColumnar: Boolean = true @@ -146,6 +152,42 @@ case class CometIcebergNativeScanExec( baseMetrics ++ icebergMetrics + ("num_splits" -> numSplitsMetric) } + /** Executes using split mode RDD - split data must be available. 
*/ + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + require( + commonData.nonEmpty && perPartitionData.nonEmpty, + "IcebergScan requires split serialization data (commonData and perPartitionData)") + + val nativeMetrics = CometMetricNode.fromCometPlan(this) + CometIcebergSplitRDD(sparkContext, commonData, perPartitionData, output.length, nativeMetrics) + } + + /** + * Override convertBlock to preserve @transient fields (commonData, perPartitionData). The + * parent implementation uses makeCopy() which loses transient fields. + */ + override def convertBlock(): CometIcebergNativeScanExec = { + // Serialize the native plan if not already done + val newSerializedPlan = if (serializedPlanOpt.isEmpty) { + val bytes = CometExec.serializeNativePlan(nativeOp) + SerializedPlan(Some(bytes)) + } else { + serializedPlanOpt + } + + // Create new instance preserving transient fields + CometIcebergNativeScanExec( + nativeOp, + output, + originalPlan, + newSerializedPlan, + metadataLocation, + numPartitions, + nativeIcebergScanMetadata, + commonData, + perPartitionData) + } + override protected def doCanonicalize(): CometIcebergNativeScanExec = { CometIcebergNativeScanExec( nativeOp, @@ -182,40 +224,22 @@ case class CometIcebergNativeScanExec( object CometIcebergNativeScanExec { - /** - * Creates a CometIcebergNativeScanExec from a Spark BatchScanExec. - * - * Determines the number of partitions from Iceberg's output partitioning: - * - KeyGroupedPartitioning: Use Iceberg's partition count - * - Other cases: Use the number of InputPartitions from Iceberg's planning - * - * @param nativeOp - * The serialized native operator - * @param scanExec - * The original Spark BatchScanExec - * @param session - * The SparkSession - * @param metadataLocation - * Path to table metadata file - * @param nativeIcebergScanMetadata - * Pre-extracted Iceberg metadata from planning phase - * @return - * A new CometIcebergNativeScanExec - */ + /** Creates a CometIcebergNativeScanExec with split serialization data. */ def apply( nativeOp: Operator, scanExec: BatchScanExec, session: SparkSession, metadataLocation: String, - nativeIcebergScanMetadata: CometIcebergNativeScanMetadata): CometIcebergNativeScanExec = { + nativeIcebergScanMetadata: CometIcebergNativeScanMetadata, + commonData: Array[Byte], + perPartitionData: Array[Array[Byte]]): CometIcebergNativeScanExec = { - // Determine number of partitions from Iceberg's output partitioning - val numParts = scanExec.outputPartitioning match { - case p: KeyGroupedPartitioning => - p.numPartitions - case _ => - scanExec.inputRDD.getNumPartitions - } + // Use perPartitionData.length as the source of truth for partition count. + // This ensures consistency between the serialized per-partition data and + // the number of partitions reported by this operator. + // Note: scanExec.outputPartitioning (KeyGroupedPartitioning) may report + // a different count due to Iceberg's logical partitioning scheme. 
+ val numParts = perPartitionData.length val exec = CometIcebergNativeScanExec( nativeOp, @@ -224,7 +248,9 @@ object CometIcebergNativeScanExec { SerializedPlan(None), metadataLocation, numParts, - nativeIcebergScanMetadata) + nativeIcebergScanMetadata, + commonData, + perPartitionData) scanExec.logicalLink.foreach(exec.setLogicalLink) exec diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergSplitRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergSplitRDD.scala new file mode 100644 index 0000000000..badc65893d --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergSplitRDD.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet + +import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometExecIterator +import org.apache.comet.serde.OperatorOuterClass +import org.apache.comet.serde.OperatorOuterClass.IcebergFilePartition + +/** + * Custom partition for split Iceberg serialization. Holds only bytes for this partition's file + * scan tasks. + */ +private[spark] class CometIcebergSplitPartition( + override val index: Int, + val partitionBytes: Array[Byte]) + extends Partition + +/** + * RDD for split Iceberg scan serialization that avoids sending all partition data to every task. + * + * With split serialization: + * - commonData: serialized IcebergScanCommon (pools, metadata) - captured in closure + * - perPartitionData: Array of serialized IcebergFilePartition - populates Partition objects + * + * Each task receives commonData (via closure) + partitionBytes (via Partition), combines them + * into an IcebergScan with split_mode=true, and passes to native execution. 
+ */ +private[spark] class CometIcebergSplitRDD( + sc: SparkContext, + commonData: Array[Byte], + @transient perPartitionData: Array[Array[Byte]], + numParts: Int, + var computeFunc: (Array[Byte], CometMetricNode, Int, Int) => Iterator[ColumnarBatch]) + extends RDD[ColumnarBatch](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + perPartitionData.zipWithIndex.map { case (bytes, idx) => + new CometIcebergSplitPartition(idx, bytes) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val partition = split.asInstanceOf[CometIcebergSplitPartition] + + val combinedPlan = + CometIcebergSplitRDD.buildCombinedPlan(commonData, partition.partitionBytes) + + // Use cached numParts to avoid triggering getPartitions() on executor + val it = computeFunc(combinedPlan, null, numParts, partition.index) + + Option(context).foreach { ctx => + ctx.addTaskCompletionListener[Unit] { _ => + it.asInstanceOf[CometExecIterator].close() + } + } + + it + } +} + +object CometIcebergSplitRDD { + + def apply( + sc: SparkContext, + commonData: Array[Byte], + perPartitionData: Array[Array[Byte]], + numOutputCols: Int, + nativeMetrics: CometMetricNode): CometIcebergSplitRDD = { + + // Create compute function that captures nativeMetrics in its closure + val computeFunc = + (combinedPlan: Array[Byte], _: CometMetricNode, numParts: Int, partIndex: Int) => { + new CometExecIterator( + CometExec.newIterId, + Seq.empty, + numOutputCols, + combinedPlan, + nativeMetrics, + numParts, + partIndex, + None, + Seq.empty) + } + + val numParts = perPartitionData.length + new CometIcebergSplitRDD(sc, commonData, perPartitionData, numParts, computeFunc) + } + + private[comet] def buildCombinedPlan( + commonBytes: Array[Byte], + partitionBytes: Array[Byte]): Array[Byte] = { + val common = OperatorOuterClass.IcebergScanCommon.parseFrom(commonBytes) + val partition = IcebergFilePartition.parseFrom(partitionBytes) + + val scanBuilder = OperatorOuterClass.IcebergScan.newBuilder() + scanBuilder.setCommon(common) + scanBuilder.setPartition(partition) + + val opBuilder = OperatorOuterClass.Operator.newBuilder() + opBuilder.setIcebergScan(scanBuilder) + + opBuilder.build().toByteArray + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 6f33467efe..5fb6638c5a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -55,10 +55,75 @@ import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, Co import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} -import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} +import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, IcebergFilePartition, Operator} import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType} import org.apache.comet.serde.operator.CometSink +/** + * Helper object for building per-partition native plans with injected IcebergScan partition data. 
+ */ +private[comet] object IcebergPartitionInjector { + + /** + * Injects partition data into an Operator tree by finding IcebergScan nodes without partition + * data and setting them using the provided map keyed by metadata_location. + * + * This handles joins over multiple Iceberg tables by matching each IcebergScan with its + * corresponding partition data based on the table's metadata_location. + * + * @param op + * The operator tree to modify + * @param partitionDataByLocation + * Map of metadataLocation -> partition bytes for this partition index + * @return + * New operator tree with partition data injected + */ + def injectPartitionData( + op: Operator, + partitionDataByLocation: Map[String, Array[Byte]]): Operator = { + val builder = op.toBuilder + + // If this is an IcebergScan without partition data, inject it based on metadata_location + if (op.hasIcebergScan) { + val scan = op.getIcebergScan + if (!scan.hasPartition && scan.hasCommon) { + val metadataLocation = scan.getCommon.getMetadataLocation + partitionDataByLocation.get(metadataLocation) match { + case Some(partitionBytes) => + val partition = IcebergFilePartition.parseFrom(partitionBytes) + val scanBuilder = scan.toBuilder + scanBuilder.setPartition(partition) + builder.setIcebergScan(scanBuilder) + case None => + // No partition data for this scan - this shouldn't happen in split mode + throw new CometRuntimeException( + s"No partition data found for Iceberg scan with metadata_location: $metadataLocation") + } + } + } + + // Recursively process children + builder.clearChildren() + op.getChildrenList.asScala.foreach { child => + builder.addChildren(injectPartitionData(child, partitionDataByLocation)) + } + + builder.build() + } + + /** + * Serializes an operator to bytes. + */ + def serializeOperator(op: Operator): Array[Byte] = { + val size = op.getSerializedSize + val bytes = new Array[Byte](size) + val codedOutput = CodedOutputStream.newInstance(bytes) + op.writeTo(codedOutput) + codedOutput.checkNoSpaceLeft() + bytes + } +} + /** * A Comet physical operator */ @@ -290,15 +355,47 @@ abstract class CometNativeExec extends CometExec { case None => (None, Seq.empty) } + // Check for IcebergScan with split mode data that needs per-partition injection. + // Only look within the current stage (stop at shuffle boundaries). + // Returns a map of metadataLocation -> perPartitionData to handle joins over + // multiple Iceberg tables. 
+ val icebergSplitDataByLocation: Map[String, Array[Array[Byte]]] = + findAllIcebergSplitData(this) + def createCometExecIter( inputs: Seq[Iterator[ColumnarBatch]], numParts: Int, partitionIndex: Int): CometExecIterator = { + // Get the actual serialized plan - either shared or per-partition injected + // Inject partition data if we have any IcebergScans with split data + val actualPlan = if (icebergSplitDataByLocation.nonEmpty) { + // Build a map of metadataLocation -> partitionBytes for this partition index + val partitionDataByLocation = icebergSplitDataByLocation.map { + case (metadataLocation, perPartitionData) => + if (partitionIndex < perPartitionData.length) { + metadataLocation -> perPartitionData(partitionIndex) + } else { + throw new CometRuntimeException( + s"Partition index $partitionIndex out of bounds for Iceberg scan " + + s"with metadata_location $metadataLocation " + + s"(${perPartitionData.length} partitions)") + } + } + // Inject partition data into IcebergScan nodes in the native plan + val basePlan = OperatorOuterClass.Operator.parseFrom(serializedPlanCopy) + val injected = + IcebergPartitionInjector.injectPartitionData(basePlan, partitionDataByLocation) + IcebergPartitionInjector.serializeOperator(injected) + } else { + // No split data - use plan as-is + serializedPlanCopy + } + val it = new CometExecIterator( CometExec.newIterId, inputs, output.length, - serializedPlanCopy, + actualPlan, nativeMetrics, numParts, partitionIndex, @@ -440,6 +537,47 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Find ALL CometIcebergNativeScanExec nodes with split mode data in the plan tree. Returns a + * map of metadataLocation -> perPartitionData for all Iceberg scans found. + * + * This supports joins over multiple Iceberg tables by collecting partition data from each scan + * and keying by metadata_location (which is unique per table). + * + * NOTE: This is only used when Iceberg scans are NOT executed via their own RDD. When Iceberg + * scans execute via CometIcebergSplitRDD, the partition data is handled there and this function + * returns an empty map. + * + * Stops at stage boundaries (shuffle exchanges, etc.) because partition indices are only valid + * within the same stage. + */ + private def findAllIcebergSplitData(plan: SparkPlan): Map[String, Array[Array[Byte]]] = { + plan match { + // Found an Iceberg scan with split data + case iceberg: CometIcebergNativeScanExec + if iceberg.commonData.nonEmpty && iceberg.perPartitionData.nonEmpty => + Map(iceberg.metadataLocation -> iceberg.perPartitionData) + + // For broadcast stages, we CAN look inside because broadcast data is replicated + // to all partitions, so partition indices align. This handles broadcast joins + // over Iceberg tables. 
+ case bqs: BroadcastQueryStageExec => + findAllIcebergSplitData(bqs.plan) + case cbe: CometBroadcastExchangeExec => + cbe.children.flatMap(c => findAllIcebergSplitData(c)).toMap + + // Stage boundaries - stop searching (partition indices won't align after these) + case _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | + _: CometUnionExec | _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | + _: ReusedExchangeExec | _: CometSparkToColumnarExec => + Map.empty + + // Continue searching through other operators, combining results from all children + case _ => + plan.children.flatMap(c => findAllIcebergSplitData(c)).toMap + } + } + /** * Converts this native Comet operator and its children into a native block which can be * executed as a whole (i.e., in a single JNI call) from the native side.
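
The sketch below illustrates, under stated assumptions, what the split serialization introduced in this patch produces for a single partition: the driver serializes one `IcebergScanCommon` (dedup pools, catalog properties, metadata location) plus one `IcebergFilePartition` per Spark partition, and the executor recombines the two halves into a complete `IcebergScan` operator before handing the bytes to native code. The recombination mirrors `CometIcebergSplitRDD.buildCombinedPlan` rather than calling it (that helper is `private[comet]`); the metadata location, catalog property, and the empty partition are placeholder values for illustration, not taken from the patch.

```scala
import org.apache.comet.serde.OperatorOuterClass
import org.apache.comet.serde.OperatorOuterClass.{IcebergFilePartition, IcebergScanCommon}

object SplitPlanRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Common half: shared pools and table metadata, serialized once and captured in the closure.
    val common = IcebergScanCommon
      .newBuilder()
      .setMetadataLocation("s3://bucket/warehouse/db/tbl/metadata/v3.metadata.json") // placeholder
      .putCatalogProperties("io-impl", "org.apache.iceberg.aws.s3.S3FileIO") // placeholder
      .build()

    // Partition half: only this partition's file scan tasks (left empty here for brevity).
    val partition = IcebergFilePartition.newBuilder().build()

    // Executor-side recombination, mirroring CometIcebergSplitRDD.buildCombinedPlan:
    // IcebergScan = common + this partition's tasks, wrapped in an Operator for the native planner.
    val scanBuilder = OperatorOuterClass.IcebergScan
      .newBuilder()
      .setCommon(IcebergScanCommon.parseFrom(common.toByteArray))
      .setPartition(IcebergFilePartition.parseFrom(partition.toByteArray))
    val planBytes =
      OperatorOuterClass.Operator.newBuilder().setIcebergScan(scanBuilder).build().toByteArray

    // The native side (planner.rs) rejects plans that are missing either half.
    val op = OperatorOuterClass.Operator.parseFrom(planBytes)
    assert(op.hasIcebergScan)
    assert(op.getIcebergScan.hasCommon && op.getIcebergScan.hasPartition)
    assert(op.getIcebergScan.getCommon.getMetadataLocation == common.getMetadataLocation)
  }
}
```

When the Iceberg scan is not executed through its own `CometIcebergSplitRDD` (for example when it sits under a join inside a larger native block), `operators.scala` takes the complementary path shown above: it parses the shared plan bytes once per task and uses `IcebergPartitionInjector.injectPartitionData` to graft the matching `IcebergFilePartition` into each `IcebergScan` node, keyed by `metadata_location` so that joins over multiple Iceberg tables each receive their own partition data.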