diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala index 951f4735444d..6217a7a44274 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosConfig.scala @@ -123,6 +123,7 @@ private[spark] object CosmosConfigNames { val WriteBulkInitialBatchSize = "spark.cosmos.write.bulk.initialBatchSize" val WriteBulkTransactionalMaxOperationsConcurrency = "spark.cosmos.write.bulk.transactional.maxOperationsConcurrency" val WriteBulkTransactionalMaxBatchesConcurrency = "spark.cosmos.write.bulk.transactional.maxBatchesConcurrency" + val WriteBulkTransactionalMarkerTtlSeconds = "spark.cosmos.write.bulk.transactional.marker.ttlSeconds" val WritePointMaxConcurrency = "spark.cosmos.write.point.maxConcurrency" val WritePatchDefaultOperationType = "spark.cosmos.write.patch.defaultOperationType" val WritePatchColumnConfigs = "spark.cosmos.write.patch.columnConfigs" @@ -1509,7 +1510,8 @@ private case class CosmosWriteBulkExecutionConfigs( private case class CosmosWriteTransactionalBulkExecutionConfigs( maxConcurrentCosmosPartitions: Option[Int] = None, maxConcurrentOperations: Option[Int] = None, - maxConcurrentBatches: Option[Int] = None) extends CosmosWriteBulkExecutionConfigsBase + maxConcurrentBatches: Option[Int] = None, + markerTtlSeconds: Option[Int] = None) extends CosmosWriteBulkExecutionConfigsBase private object CosmosWriteConfig { private val DefaultMaxRetryCount = 10 @@ -1601,6 +1603,22 @@ private object CosmosWriteConfig { helpMessage = "Max concurrent transactional batches per Cosmos partition (1..5). Controls batch-level parallelism; default 5." 
+ "Each batch may contain multiple operations; tune together with 'spark.cosmos.write.bulk.transactional.maxOperationsConcurrency' to balance throughput and throttling.") + private val bulkTransactionalMarkerTtlSeconds = CosmosConfigEntry[Int]( + key = CosmosConfigNames.WriteBulkTransactionalMarkerTtlSeconds, + defaultValue = Option.apply(86400), + mandatory = false, + parseFromStringFunction = ttlSeconds => { + val value = ttlSeconds.toInt + if (value <= 0) { + throw new IllegalArgumentException( + s"'${CosmosConfigNames.WriteBulkTransactionalMarkerTtlSeconds}' must be a positive number of seconds, but was $value.") + } + value + }, + helpMessage = "TTL in seconds for batch marker documents used for retry ambiguity resolution in transactional bulk mode. " + + "Markers are actively deleted after each batch completes; TTL is defense-in-depth for orphan cleanup from crashed runs. " + + "Default: 86400 (24 hours). Set to a lower value (e.g., 3600) if container storage is constrained.") + private val pointWriteConcurrency = CosmosConfigEntry[Int](key = CosmosConfigNames.WritePointMaxConcurrency, mandatory = false, parseFromStringFunction = bulkMaxConcurrencyAsString => bulkMaxConcurrencyAsString.toInt, @@ -1844,18 +1862,17 @@ private object CosmosWriteConfig { if (bulkEnabledOpt.isDefined && bulkEnabledOpt.get) { if (bulkTransactionalOpt.isDefined && bulkTransactionalOpt.get) { - // Validate write strategy for transactional batches - assert(itemWriteStrategyOpt.get == ItemWriteStrategy.ItemOverwrite, - s"Transactional batches only support ItemOverwrite (upsert) write strategy. 
Requested: ${itemWriteStrategyOpt.get}") val maxConcurrentCosmosPartitionsOpt = CosmosConfigEntry.parse(cfg, bulkMaxConcurrentPartitions) val maxBulkTransactionalOpsConcurrencyOpt = CosmosConfigEntry.parse(cfg, bulkTransactionalMaxOpsConcurrency) val maxBulkTransactionalBatchesConcurrencyOpt = CosmosConfigEntry.parse(cfg, bulkTransactionalMaxBatchesConcurrency) + val markerTtlSecondsOpt = CosmosConfigEntry.parse(cfg, bulkTransactionalMarkerTtlSeconds) bulkExecutionConfigsOpt = Some(CosmosWriteTransactionalBulkExecutionConfigs( maxConcurrentCosmosPartitionsOpt, maxBulkTransactionalOpsConcurrencyOpt, - maxBulkTransactionalBatchesConcurrencyOpt + maxBulkTransactionalBatchesConcurrencyOpt, + markerTtlSecondsOpt )) } else { diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPatchHelper.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPatchHelper.scala index 4af483df995c..cec351b5b84a 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPatchHelper.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosPatchHelper.scala @@ -42,9 +42,11 @@ private class CosmosPatchHelper(diagnosticsConfig: DiagnosticsConfig, // There are some properties are immutable, these kind properties include: // 1. System properties : _ts, _rid, _etag // 2. id, and partitionKeyPath - if ((path.startsWith("/") && !systemProperties.contains(path.substring(1)) && IdAttributeName != path.substring(1)) - && !StringUtils.join(partitionKeyDefinition.getPaths, "").contains(path)) { - true + if (path.startsWith("/") && !systemProperties.contains(path.substring(1)) && IdAttributeName != path.substring(1)) { + // Check each partition key path individually with exact match to avoid false positives. + // e.g., "/tenant" was blocked because it's + // a substring of "/tenantId" in the joined string "/tenantId/userId/sessionId". 
+ !partitionKeyDefinition.getPaths.contains(path) } else { false } diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala index 6fc19a36d851..e5a1e259b8c1 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/CosmosWriterBase.scala @@ -69,6 +69,7 @@ private abstract class CosmosWriterBase( new TransactionalBulkWriter( container, cosmosTargetContainerConfig, + partitionKeyDefinition, cosmosWriteConfig, diagnosticsConfig, getOutputMetricsPublisher(), @@ -153,6 +154,7 @@ private abstract class CosmosWriterBase( new TransactionalBulkWriter( container, cosmosTargetContainerConfig, + partitionKeyDefinition, cosmosWriteConfig, diagnosticsConfig, getOutputMetricsPublisher(), diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala index d977d8ece28f..06c5d7c85c29 100644 --- a/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala +++ b/sdk/cosmos/azure-cosmos-spark_3/src/main/scala/com/azure/cosmos/spark/TransactionalBulkWriter.scala @@ -3,13 +3,14 @@ package com.azure.cosmos.spark // scalastyle:off underscore.import +import com.azure.cosmos.implementation.apachecommons.lang.StringUtils import com.azure.cosmos.implementation.batch.{BulkExecutorDiagnosticsTracker, CosmosBatchBulkOperation, CosmosBulkTransactionalBatchResponse, TransactionalBulkExecutor} import com.azure.cosmos.implementation.{CosmosTransactionalBulkExecutionOptionsImpl, UUIDs} -import com.azure.cosmos.models.{CosmosBatch, CosmosBatchResponse} +import com.azure.cosmos.models.{CosmosBatch, CosmosBatchItemRequestOptions, 
CosmosBatchPatchItemRequestOptions, CosmosBatchResponse, CosmosBulkOperations, CosmosItemOperation, PartitionKeyDefinition} import com.azure.cosmos.spark.BulkWriter.getThreadInfo import com.azure.cosmos.spark.TransactionalBulkWriter.{BulkOperationFailedException, DefaultMaxPendingOperationPerCore, emitFailureHandler, transactionalBatchInputBoundedElastic, transactionalBulkWriterInputBoundedElastic, transactionalBulkWriterRequestsBoundedElastic} import com.azure.cosmos.spark.diagnostics.DefaultDiagnostics -import com.azure.cosmos.{BridgeInternal, CosmosAsyncContainer, CosmosDiagnosticsContext, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosException} +import com.azure.cosmos.{BridgeInternal, CosmosAsyncContainer, CosmosDiagnosticsContext, CosmosEndToEndOperationLatencyPolicyConfigBuilder, CosmosException, SparkBridgeInternal} import reactor.core.Scannable import reactor.core.scala.publisher.SMono.PimpJFlux import reactor.core.scheduler.Scheduler @@ -24,6 +25,7 @@ import com.azure.cosmos.implementation.guava25.base.Preconditions import com.azure.cosmos.implementation.spark.{OperationContextAndListenerTuple, OperationListener} import com.azure.cosmos.models.PartitionKey import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext} +import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.node.ObjectNode import org.apache.spark.TaskContext import reactor.core.Disposable @@ -46,6 +48,7 @@ private class TransactionalBulkWriter ( container: CosmosAsyncContainer, containerConfig: CosmosContainerConfig, + partitionKeyDefinition: PartitionKeyDefinition, writeConfig: CosmosWriteConfig, diagnosticsConfig: DiagnosticsConfig, outputMetricsPublisher: OutputMetricsPublisherTrait, @@ -96,7 +99,7 @@ private class TransactionalBulkWriter private val transactionalBatchInputEmitter: Sinks.Many[CosmosBatchBulkOperation] = Sinks.many().unicast().onBackpressureBuffer() // for transactional batch, 
all rows/items from the dataframe should be grouped as one cosmos batch - private val transactionalBatchPartitionKeyScheduled = java.util.concurrent.ConcurrentHashMap.newKeySet[PartitionKey]().asScala + private val transactionalBatchPartitionKeyScheduled = java.util.concurrent.ConcurrentHashMap.newKeySet[String]().asScala private val semaphore = new Semaphore(maxPendingBatches) @@ -140,6 +143,67 @@ private class TransactionalBulkWriter ThroughputControlHelper.populateThroughputControlGroupName(cosmosTransactionalBulkExecutionOptions, writeConfig.throughputControlConfig) + private val cosmosPatchHelperOpt = writeConfig.itemWriteStrategy match { + case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists | ItemWriteStrategy.ItemBulkUpdate => + Some(new CosmosPatchHelper(diagnosticsConfig, writeConfig.patchConfigs.get)) + case _ => None + } + + // Idempotency is determined once at construction time — the strategy and patch configs are immutable. + // Only ItemPatch/ItemPatchIfExists with Increment operations are non-idempotent (double-applying + // an increment silently corrupts counters). All other strategies produce the same result on retry. + // This flag gates the flushAndClose re-enqueue path (NOT shouldRetry). + private val batchIsIdempotent: Boolean = writeConfig.itemWriteStrategy match { + case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists => + writeConfig.patchConfigs match { + case Some(patchConfigs) => + !patchConfigs.columnConfigsMap.values.exists(_.operationType == CosmosPatchOperationTypes.Increment) + case None => true // no patch configs → no increment → idempotent + } + case _ => true // all non-patch strategies are idempotent on retry + } + + // --- Batch Marker + // Without the marker, retry ambiguity causes false positives (inferring SUCCESS when the batch + // never committed — silently losing N-1 items) and false negatives (inferring FAIL when the + // batch actually committed — producing spurious errors). 
The marker is an atomic proof-of-commit: + // it is written as the last upsert in every batch, and on ambiguous retry, a point-read of the + // marker determines whether the original batch committed. After the outcome is determined, + // the marker is actively deleted. TTL is defense-in-depth for orphan cleanup from crashed runs. + // Configurable via spark.cosmos.write.bulk.transactional.marker.ttlSeconds (default 86400 = 24 hours). + private val markerTtlSeconds = transactionalBulkExecutionConfigs.markerTtlSeconds + .getOrElse(TransactionalBulkWriter.DefaultMarkerTtlSeconds) + private val batchSequenceCounter = new AtomicLong(0) + // jobRunId + sparkPartitionId + batchSeq make each marker ID globally unique. + private val (sparkPartitionId: Int, jobRunId: String) = { + val tc = TaskContext.get + if (tc != null) { + (tc.partitionId(), tc.taskAttemptId().toString) + } else { + (-1, UUIDs.nonBlockingRandomUUID().toString) + } + } + // Partition key paths (e.g., List("/tenantId", "/userId", "/sessionId")) + // needed to populate PK fields in marker documents + private val partitionKeyPaths: List[String] = partitionKeyDefinition.getPaths.asScala.toList + + // Container TTL startup check — log INFO if TTL is not enabled. + // Active deletion is the primary cleanup mechanism; TTL is defense-in-depth only. + { + try { + val containerProperties = SparkBridgeInternal.getContainerPropertiesFromCollectionCache(container) + val defaultTtl = containerProperties.getDefaultTimeToLiveInSeconds + if (defaultTtl == null) { + log.logInfo(s"Container TTL is not enabled. Marker documents will be cleaned up via active deletion. " + + s"Enable TTL (defaultTimeToLive = -1) for defense-in-depth cleanup of orphaned markers.") + } + } catch { + case ex: Exception => + log.logDebug(s"Unable to check container TTL setting: ${ex.getMessage}. 
" + + s"Marker cleanup will rely on active deletion.") + } + } + private val operationContext = initializeOperationContext() private def initializeOperationContext(): SparkTaskContext = { @@ -227,12 +291,63 @@ private class TransactionalBulkWriter val cosmosBatch = CosmosBatch.createCosmosBatch(bulkItemsList.get(0).partitionKey) bulkItemsList.forEach(bulkItem => { writeConfig.itemWriteStrategy match { - case ItemWriteStrategy.ItemOverwrite => cosmosBatch.upsertItemOperation(bulkItem.objectNode) - case _ => throw new IllegalStateException(s"Item write strategy ${writeConfig.itemWriteStrategy} is not supported for bulk with transactional") + case ItemWriteStrategy.ItemOverwrite => + cosmosBatch.upsertItemOperation(bulkItem.objectNode) + + case ItemWriteStrategy.ItemAppend => + cosmosBatch.createItemOperation(bulkItem.objectNode) + + case ItemWriteStrategy.ItemDelete => + cosmosBatch.deleteItemOperation(bulkItem.itemId) + + case ItemWriteStrategy.ItemDeleteIfNotModified => + val requestOptions = new CosmosBatchItemRequestOptions() + bulkItem.eTag.foreach(etag => requestOptions.setIfMatchETag(etag)) + cosmosBatch.deleteItemOperation(bulkItem.itemId, requestOptions) + + case ItemWriteStrategy.ItemOverwriteIfNotModified => + bulkItem.eTag match { + case Some(etag) => + val requestOptions = new CosmosBatchItemRequestOptions() + requestOptions.setIfMatchETag(etag) + cosmosBatch.replaceItemOperation(bulkItem.itemId, bulkItem.objectNode, requestOptions) + case None => + cosmosBatch.createItemOperation(bulkItem.objectNode) + } + + case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists => + val patchOps = cosmosPatchHelperOpt.get.createCosmosPatchOperations( + bulkItem.itemId, partitionKeyDefinition, bulkItem.objectNode) + val requestOptions = new CosmosBatchPatchItemRequestOptions() + val patchConfigs = writeConfig.patchConfigs.get + if (patchConfigs.filter.isDefined && !StringUtils.isEmpty(patchConfigs.filter.get)) { + 
requestOptions.setFilterPredicate(patchConfigs.filter.get) + } + cosmosBatch.patchItemOperation(bulkItem.itemId, patchOps, requestOptions) + + case _ => + throw new IllegalStateException( + s"Item write strategy ${writeConfig.itemWriteStrategy} is not supported for bulk with transactional") } }) - scheduleBatch(cosmosBatch) + // Append marker as the last upsert in the batch for retry ambiguity resolution. + // Skip marker if batch already has 100 items + val markerId = if (bulkItemsList.size() < TransactionalBulkWriter.MaxOperationsPerBatch) { + val batchSeq = batchSequenceCounter.incrementAndGet() + val id = s"__tbw:$jobRunId:$sparkPartitionId:$batchSeq" + val markerNode = buildMarkerDocument(id, bulkItemsList.get(0).objectNode) + cosmosBatch.upsertItemOperation(markerNode) + Some(id) + } else { + // batch has 100 business items — skip marker, fall back to shouldIgnore-only + log.logInfo(s"Batch for PK '${bulkItemsList.get(0).partitionKey}' has " + + s"${bulkItemsList.size()} items (server limit). Marker skipped — using shouldIgnore-only inference. 
" + + s"Context: ${operationContext.toString} $getThreadInfo") + None + } + + scheduleBatch(cosmosBatch, bulkItemsList.asScala.toList, markerId) SMono.empty } }) @@ -294,7 +409,9 @@ private class TransactionalBulkWriter batchOperation.cosmosBatchBulkOperation, None, isGettingRetried, - Some(cosmosException)) + Some(cosmosException), + batchOperation.originalItems, + batchOperation.markerId) case _ => log.logWarning( s"unexpected failure: partitionKeyValue=[" + @@ -310,11 +427,19 @@ private class TransactionalBulkWriter batchOperation.cosmosBatchBulkOperation, Some(resp.getResponse), isGettingRetried, - None) + None, + batchOperation.originalItems, + batchOperation.markerId) } else { - // no error case - outputMetricsPublisher.trackWriteOperation(resp.getResponse.size(), None) - totalSuccessfulIngestionMetrics.addAndGet(resp.getResponse.size()) + // Happy path: batch succeeded on first attempt + // Use originalItems.size (business items only), not resp.getResponse.size() + // which includes the internal marker operation and would inflate metrics. 
+ outputMetricsPublisher.trackWriteOperation(batchOperation.originalItems.size, None) + totalSuccessfulIngestionMetrics.addAndGet(batchOperation.originalItems.size) + // Best-effort marker cleanup — marker is no longer needed + deleteMarkerBestEffort( + batchOperation.markerId, + batchOperation.cosmosBatchBulkOperation.getPartitionKeyValue) } } } @@ -350,11 +475,23 @@ private class TransactionalBulkWriter Preconditions.checkState(!closed.get()) throwIfCapturedExceptionExists() - val transactionalBulkItem = TransactionalBulkItem(partitionKeyValue, objectNode) + val itemId = getId(objectNode) + val eTag = getETag(objectNode) + + val transactionalBulkItem = TransactionalBulkItem(partitionKeyValue, objectNode, itemId, eTag) transactionalBulkInputEmitter.emitNext(transactionalBulkItem, emitFailureHandler) } - private def scheduleBatch(cosmosBatch: CosmosBatch): Unit = { + private def getId(objectNode: ObjectNode): String = { + val idField = objectNode.get(CosmosConstants.Properties.Id) + if (idField == null || !idField.isTextual) { + throw new IllegalArgumentException( + s"The required '${CosmosConstants.Properties.Id}' field is missing or not a string in the document being written to Cosmos DB.") + } + idField.textValue() + } + + private def scheduleBatch(cosmosBatch: CosmosBatch, originalItems: List[TransactionalBulkItem], markerId: Option[String]): Unit = { Preconditions.checkState(!closed.get()) throwIfCapturedExceptionExists() @@ -364,11 +501,14 @@ private class TransactionalBulkWriter new OperationContext( cosmosBatch.getPartitionKeyValue, 1, - monotonicOperationCounter.incrementAndGet()) - if (!transactionalBatchPartitionKeyScheduled.add(cosmosBatch.getPartitionKeyValue)) { - log.logError(s"There are already existing cosmos batch operation scheduled for partition key ${cosmosBatch.getPartitionKeyValue}," + - s" transactional is not guaranteed, fail") - SMono.error(new IllegalStateException(s"Transactional is not guaranteed for partition key 
${cosmosBatch.getPartitionKeyValue}")) + monotonicOperationCounter.incrementAndGet(), + batchIsIdempotent) + val partitionKeyString = cosmosBatch.getPartitionKeyValue.toString + if (!transactionalBatchPartitionKeyScheduled.add(partitionKeyString)) { + log.logError(s"Partition key value '$partitionKeyString' has already been scheduled in this writer instance. " + + s"This indicates a bug in the data distribution or ordering pipeline. " + + s"Atomicity guarantee may be violated for this partition key value. " + + s"Context: ${operationContext.toString} $getThreadInfo") } val numberOfIntervalsWithIdenticalActiveOperationSnapshots = new AtomicLong(0) @@ -400,10 +540,10 @@ private class TransactionalBulkWriter pendingCosmosBatchSnapshot = pendingBatchRetries.clone() } - val cnt = totalScheduledMetrics.getAndAdd(cosmosBatch.getOperations.size()) + val cnt = totalScheduledMetrics.getAndAdd(originalItems.size) log.logTrace(s"total scheduled $cnt, Context: ${operationContext.toString} $getThreadInfo") - scheduleBatchInternal(CosmosBatchOperation(cosmosBatchBulkOperation, operationContext)) + scheduleBatchInternal(CosmosBatchOperation(cosmosBatchBulkOperation, operationContext, originalItems, markerId)) } private def scheduleBatchInternal(cosmosBatchOperation: CosmosBatchOperation): Unit = { @@ -428,7 +568,9 @@ private class TransactionalBulkWriter cosmosBatchBulkOperation: CosmosBatchBulkOperation, cosmosBatchResponse: Option[CosmosBatchResponse], isGettingRetried: AtomicBoolean, - responseException: Option[CosmosException] + responseException: Option[CosmosException], + originalItems: List[TransactionalBulkItem], + markerId: Option[String] ) : Unit = { val exceptionMessage = cosmosBatchResponse match { @@ -459,7 +601,87 @@ private class TransactionalBulkWriter s"$effectiveStatusCode:$effectiveSubStatusCode, " + s"Context: ${operationContext.toString} $getThreadInfo") - if (shouldRetry(effectiveStatusCode, effectiveSubStatusCode, operationContext)) { + if 
(shouldIgnoreOnRetry(operationContext, cosmosBatchResponse)) { + // shouldIgnore matched — but we need to verify before inferring SUCCESS. + markerId match { + case Some(id) => + // Normal batch (has marker) -> verify via marker point-read + val verificationOutcome = verifyBatchCommit(id, cosmosBatchBulkOperation.getPartitionKeyValue) + verificationOutcome match { + case Committed => + log.logInfo(s"for partitionKeyValue=[${operationContext.partitionKeyValueInput}], " + + s"marker verification confirmed COMMITTED. markerId='$id', " + + s"statusCode='$effectiveStatusCode:$effectiveSubStatusCode', " + + s"attemptNumber=${operationContext.attemptNumber}, " + + s"Context: {${operationContext.toString}} $getThreadInfo") + outputMetricsPublisher.trackWriteOperation(originalItems.size, None) + totalSuccessfulIngestionMetrics.addAndGet(originalItems.size) + deleteMarkerBestEffort(markerId, cosmosBatchBulkOperation.getPartitionKeyValue) + + case NotCommitted => + log.logInfo(s"for partitionKeyValue=[${operationContext.partitionKeyValueInput}], " + + s"marker verification found NOT COMMITTED (marker absent). " + + s"shouldIgnore error was from external process, not our batch. markerId='$id', " + + s"statusCode='$effectiveStatusCode:$effectiveSubStatusCode', " + + s"attemptNumber=${operationContext.attemptNumber}, " + + s"Context: {${operationContext.toString}} $getThreadInfo") + // FAIL — the original batch did not commit + val message = s"Batch not committed (marker absent after shouldIgnore match) - " + + s"statusCode=[$effectiveStatusCode:$effectiveSubStatusCode] " + + s"partitionKeyValue=[${operationContext.partitionKeyValueInput}]" + captureIfFirstFailure( + new BulkOperationFailedException(effectiveStatusCode, effectiveSubStatusCode, message, null)) + cancelWork() + + case Inconclusive => + // Verification read itself failed — consume one retry attempt. + // Retry eligibility is based solely on remaining attempt budget, NOT on the original + // batch status code. 
The original batch returned a shouldIgnore-eligible code (e.g., 409 + // for ItemAppend) which is NOT transient — passing it to shouldRetry() would incorrectly + // reject the retry even though attempts remain. The verification read failed transiently; + // the only question is whether we have retry budget left. + log.logWarning(s"for partitionKeyValue=[${operationContext.partitionKeyValueInput}], " + + s"marker verification inconclusive (read failed). Will retry. markerId='$id', " + + s"attemptNumber=${operationContext.attemptNumber}, " + + s"Context: {${operationContext.toString}} $getThreadInfo") + if (operationContext.attemptNumber < writeConfig.maxRetryCount) { + val batchOperationRetry = CosmosBatchOperation( + cosmosBatchBulkOperation, + new OperationContext( + operationContext.partitionKeyValueInput, + operationContext.attemptNumber + 1, + operationContext.sequenceNumber, + operationContext.isIdempotent), + originalItems, + markerId + ) + this.scheduleRetry( + trackPendingRetryAction = () => pendingBatchRetries.put(cosmosBatchBulkOperation.getPartitionKeyValue, batchOperationRetry).isEmpty, + clearPendingRetryAction = () => pendingBatchRetries.remove(cosmosBatchBulkOperation.getPartitionKeyValue).isDefined, + batchOperationRetry, + effectiveStatusCode) + isGettingRetried.set(true) + } else { + val message = s"Marker verification inconclusive and retries exhausted - " + + s"statusCode=[$effectiveStatusCode:$effectiveSubStatusCode] " + + s"partitionKeyValue=[${operationContext.partitionKeyValueInput}]" + captureIfFirstFailure( + new BulkOperationFailedException(effectiveStatusCode, effectiveSubStatusCode, message, null)) + cancelWork() + } + } + + case None => + // batch (100 items, marker skipped) -> shouldIgnore-only inference + log.logInfo(s"for partitionKeyValue=[${operationContext.partitionKeyValueInput}], " + + s"inferred SUCCESS on retry via shouldIgnore. 
" + + s"statusCode='$effectiveStatusCode:$effectiveSubStatusCode', " + + s"attemptNumber=${operationContext.attemptNumber}, " + + s"Context: {${operationContext.toString}} $getThreadInfo") + outputMetricsPublisher.trackWriteOperation(originalItems.size, None) + totalSuccessfulIngestionMetrics.addAndGet(originalItems.size) + } + } else if (shouldRetry(effectiveStatusCode, effectiveSubStatusCode, operationContext)) { // requeue log.logWarning(s"for partitionKeyValue=[${operationContext.partitionKeyValueInput}], " + s"encountered status code '$effectiveStatusCode:$effectiveSubStatusCode', will retry! " + @@ -471,7 +693,10 @@ private class TransactionalBulkWriter new OperationContext( operationContext.partitionKeyValueInput, operationContext.attemptNumber + 1, - operationContext.sequenceNumber) + operationContext.sequenceNumber, + operationContext.isIdempotent), + originalItems, + markerId ) this.scheduleRetry( @@ -611,14 +836,17 @@ private class TransactionalBulkWriter if (maxAllowedIntervalWithoutAnyProgressExceeded) { val exception = { // order by batch sequence number - // then return all operations inside the batch + // then return all original items for re-scheduling + // Use originalItems instead of batch.getOperations to avoid NPE on delete operations + // (getItem[ObjectNode] returns null for delete operations) val retriableRemainingOperations = if (allowRetryOnNewBulkWriterInstance) { Some( (pendingRetriesSnapshot ++ activeOperationsSnapshot) .toList .sortBy(op => op._2.operationContext.sequenceNumber) - .map(batchOperationPartitionKeyPair => batchOperationPartitionKeyPair._2.cosmosBatchBulkOperation.getCosmosBatch) - .flatMap(batch => batch.getOperations.asScala) + .flatMap(batchOperationPartitionKeyPair => batchOperationPartitionKeyPair._2.originalItems) + .map(item => CosmosBulkOperations.getUpsertItemOperation( + item.objectNode, item.partitionKey).asInstanceOf[CosmosItemOperation]) ) } else { None @@ -709,14 +937,27 @@ private class 
TransactionalBulkWriter activeOperationsSnapshot.foreach(operationPartitionKeyPair => { if (activeBatches.contains(operationPartitionKeyPair._1)) { - // re-validating whether the operation is still active - if so, just re-enqueue another retry - // this is harmless - because all bulkItemOperations from Spark connector are always idempotent - // For FAIL_NON_SERIALIZED, will keep retry, while for other errors, use the default behavior - transactionalBatchInputEmitter.emitNext(operationPartitionKeyPair._2.cosmosBatchBulkOperation, TransactionalBulkWriter.emitFailureHandler) - log.logWarning(s"Re-enqueued a retry for pending active batch task " - + s"(${operationPartitionKeyPair._1})' " - + s"- Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - " - + s"Context: ${operationContext.toString} $getThreadInfo") + val batchOp = operationPartitionKeyPair._2 + if (!batchOp.operationContext.isIdempotent) { + // Skip re-enqueue for non-idempotent operations (e.g., increment patch). + // The original batch may still be in-flight — re-enqueuing would cause + // concurrent double-execution, silently corrupting data (e.g., counters + // incremented twice). Allow no-progress detection to handle instead. + log.logWarning(s"Skipping re-enqueue for non-idempotent batch operation " + + s"(${operationPartitionKeyPair._1}). " + + s"Allowing no-progress detection to handle. 
" + + s"- Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - " + + s"Context: ${operationContext.toString} $getThreadInfo") + } else { + // re-validating whether the operation is still active - if so, just re-enqueue another retry + // this is safe for idempotent operations - double execution produces the same result + transactionalBatchInputEmitter.emitNext( + batchOp.cosmosBatchBulkOperation, TransactionalBulkWriter.emitFailureHandler) + log.logWarning(s"Re-enqueued a retry for pending active batch task " + + s"(${operationPartitionKeyPair._1})' " + + s"- Attempt: ${numberOfIntervalsWithIdenticalActiveOperationSnapshots.get} - " + + s"Context: ${operationContext.toString} $getThreadInfo") + } } }) } @@ -748,9 +989,14 @@ private class TransactionalBulkWriter transactionalBatchInputEmitter.emitComplete(TransactionalBulkWriter.emitFailureHandlerForComplete) throwIfCapturedExceptionExists() - assume(activeBatchTasks.get() <= 0) - assume(activeBatches.isEmpty) - assume(semaphore.availablePermits() >= maxPendingBatches) + if (activeBatchTasks.get() > 0) { + log.logWarning(s"flushAndClose completed but activeBatchTasks=${activeBatchTasks.get()} > 0. " + + s"Context: ${operationContext.toString} $getThreadInfo") + } + if (activeBatches.nonEmpty) { + log.logWarning(s"flushAndClose completed but activeBatches is not empty (size=${activeBatches.size}). " + + s"Context: ${operationContext.toString} $getThreadInfo") + } if (totalScheduledMetrics.get() != totalSuccessfulIngestionMetrics.get) { log.logWarning(s"flushAndClose completed with no error but inconsistent total success and " + @@ -816,12 +1062,169 @@ private class TransactionalBulkWriter batchSubscriptionDisposable.dispose() } + // Builds a minimal marker document with id + ttl + partition key fields. + // The marker's PK field values are copied from the first business item in the batch + // so the marker lands in the same logical partition. 
+ private def buildMarkerDocument(markerId: String, firstBusinessItem: ObjectNode): ObjectNode = { + val markerNode = TransactionalBulkWriter.markerObjectMapper.createObjectNode() + markerNode.put("id", markerId) + markerNode.put("ttl", markerTtlSeconds) + partitionKeyPaths.foreach(path => { + val fieldName = path.stripPrefix("/") + val value = firstBusinessItem.get(fieldName) + if (value != null) { + markerNode.set(fieldName, value.deepCopy()) + } + }) + markerNode + } + + // Best-effort marker cleanup — single attempt, no retry. + // The batch outcome is already determined before this is called — + // the delete is purely cleanup, not a correctness operation. + private def deleteMarkerBestEffort(markerId: Option[String], partitionKeyValue: PartitionKey): Unit = { + markerId match { + case Some(id) => + // Fire-and-forget: do not block the response handler thread. + container.deleteItem(id, partitionKeyValue) + .doOnSuccess((_: Any) => + log.logDebug(s"Marker '$id' deleted successfully. " + + s"Context: ${operationContext.toString} $getThreadInfo")) + .doOnError((ex: Throwable) => + // Best-effort: log warning, do not retry, do not propagate. + // If TTL is enabled, the marker will eventually expire. + // If TTL is disabled, the marker stays as a ~100-byte orphan — no correctness impact. + log.logWarning(s"Failed to delete marker '$id' (best-effort cleanup). " + + s"Marker is inert and will not be read again. 
" + + s"Exception: ${ex.getMessage}, " + + s"Context: ${operationContext.toString} $getThreadInfo")) + .onErrorResume((_: Throwable) => reactor.core.publisher.Mono.empty()) + .subscribeOn(Schedulers.boundedElastic()) + .subscribe() + case None => // No marker — nothing to delete + } + } + + // Marker verification outcomes + private sealed trait MarkerVerificationOutcome + private case object Committed extends MarkerVerificationOutcome // Marker present -> batch committed + private case object NotCommitted extends MarkerVerificationOutcome // Marker absent -> batch did not commit + private case object Inconclusive extends MarkerVerificationOutcome // Verification read itself failed + + // Marker verification — the ONLY decision signal for ambiguous retries. + // Returns Committed (marker present), NotCommitted (marker absent), or + // Inconclusive (verification read itself failed with a transient error). + private def verifyBatchCommit( + markerId: String, + partitionKeyValue: PartitionKey + ): MarkerVerificationOutcome = { + try { + // Bounded timeout prevents hanging on network stall, thread starvation, or extended + // SDK-internal retry loops. Timeout is treated as Inconclusive, consuming a retry attempt. + container.readItem(markerId, partitionKeyValue, classOf[ObjectNode]) + .block(TransactionalBulkWriter.MarkerVerificationTimeout) + // 200 OK — marker exists -> batch committed + Committed + } catch { + case cosmosEx: CosmosException if cosmosEx.getStatusCode == 404 => + // 404 Not Found — marker does not exist -> batch did NOT commit + NotCommitted + case cosmosEx: CosmosException + if Exceptions.canBeTransientFailure(cosmosEx.getStatusCode, cosmosEx.getSubStatusCode) => + // Transient error on the verification read itself — inconclusive + log.logWarning(s"Marker verification read failed with transient error " + + s"${cosmosEx.getStatusCode}:${cosmosEx.getSubStatusCode} for marker '$markerId'. 
" + + s"Context: ${operationContext.toString} $getThreadInfo") + Inconclusive + case ex: Exception => + // Unexpected error (including block() timeout -> IllegalStateException) — treat as inconclusive + log.logWarning(s"Marker verification read failed unexpectedly for marker '$markerId'. " + + s"Exception: ${ex.getMessage}, " + + s"Context: ${operationContext.toString} $getThreadInfo") + Inconclusive + } + } + + // Restricted subset of BulkWriter's shouldIgnore — excludes 412 (Precondition Failed) + // because 412 is ambiguous on retry for batch operations. + private def shouldIgnore(statusCode: Int, subStatusCode: Int): Boolean = { + writeConfig.itemWriteStrategy match { + case ItemWriteStrategy.ItemAppend => Exceptions.isResourceExistsException(statusCode) + case ItemWriteStrategy.ItemPatchIfExists => Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode) + case ItemWriteStrategy.ItemDelete => Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode) + // 412 is excluded — ambiguous on retry for batch operations + case ItemWriteStrategy.ItemDeleteIfNotModified => Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode) + case ItemWriteStrategy.ItemOverwriteIfNotModified => + Exceptions.isResourceExistsException(statusCode) || + Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode) + case _ => false + } + } + + // Batch-level shouldIgnore with retry + first-operation guards. + // Only infers SUCCESS when: + // 1. This is a retry (attemptNumber > 1) + // 2. We have per-operation results + // 3. The first non-424 result is on the FIRST operation (index 0) + // 4. shouldIgnore returns true for that operation's status code + // 5. 
Strategy is NOT ItemBulkUpdate (retries rebuild the batch) + private def shouldIgnoreOnRetry( + operationContext: OperationContext, + cosmosBatchResponse: Option[CosmosBatchResponse] + ): Boolean = { + // Condition 0: Must NOT be ItemBulkUpdate — retries rebuild the batch with fresh ETags, + // so the retry batch is different from the original. shouldIgnore inference is invalid. + if (writeConfig.itemWriteStrategy == ItemWriteStrategy.ItemBulkUpdate) return false + + // Condition 1: Must be a retry (not first attempt) + if (operationContext.attemptNumber <= 1) return false + + // Condition 2: Must have per-operation results + val response = cosmosBatchResponse.getOrElse(return false) + val results = response.getResults.asScala + + // Condition 3: Find the first non-424 result (424 = Failed Dependency = rolled back) + val firstNon424 = results.zipWithIndex.find { case (result, _) => + result.getStatusCode != 424 + } + + firstNon424 match { + case Some((result, index)) => + // Condition 4: Must be the FIRST operation in the batch (index 0) + if (index != 0) return false + + // Condition 5: Check shouldIgnore for the strategy + val returnValue = shouldIgnore(result.getStatusCode, result.getSubStatusCode) + if (returnValue) { + log.logDebug(s"shouldIgnoreOnRetry: true for statusCode " + + s"'${result.getStatusCode}:${result.getSubStatusCode}' on first operation, " + + s"attemptNumber=${operationContext.attemptNumber}, " + + s"Context: ${operationContext.toString} $getThreadInfo") + } + returnValue + + case None => false // All results are 424 — shouldn't happen + } + } + private def shouldRetry(statusCode: Int, subStatusCode: Int, operationContext: OperationContext): Boolean = { var returnValue = false if (operationContext.attemptNumber < writeConfig.maxRetryCount) { - returnValue = Exceptions.canBeTransientFailure(statusCode, subStatusCode) || - statusCode == 0 // Gateway mode reports inability to connect due to PoolAcquirePendingLimitException as status code 0 + 
returnValue = writeConfig.itemWriteStrategy match { + // Upsert can return 404/0 in rare cases (when due to TTL expiration there is a race condition) + case ItemWriteStrategy.ItemOverwrite => + Exceptions.canBeTransientFailure(statusCode, subStatusCode) || + statusCode == 0 || // Gateway mode: PoolAcquirePendingLimitException + Exceptions.isNotFoundExceptionCore(statusCode, subStatusCode) + case _ => + Exceptions.canBeTransientFailure(statusCode, subStatusCode) || + statusCode == 0 // Gateway mode: PoolAcquirePendingLimitException + } } + // NOTE: isIdempotent does NOT gate shouldRetry. TransactionalBulkWriter follows BulkWriter's retry behavior — retrying even non-idempotent + // operations (e.g., increment) and accepting the double-application risk. This matches BulkWriter + // which retries ItemPatch (including increment) with no special handling. + // The isIdempotent flag ONLY gates the flushAndClose re-enqueue path. log.logDebug(s"Should retry statusCode '$statusCode:$subStatusCode' -> $returnValue, " + s"Context: ${operationContext.toString} $getThreadInfo") @@ -851,17 +1254,26 @@ private class TransactionalBulkWriter val partitionKeyValueInput: PartitionKey, val attemptNumber: Int, val sequenceNumber: Long, - /** starts from 1 * */) + /** starts from 1 * */ + val isIdempotent: Boolean = true) { override def equals(obj: Any): Boolean = partitionKeyValueInput.equals(obj) override def hashCode(): Int = partitionKeyValueInput.hashCode() override def toString: String = { - partitionKeyValueInput.toString + s", attemptNumber = $attemptNumber" + partitionKeyValueInput.toString + s", attemptNumber = $attemptNumber, isIdempotent = $isIdempotent" } } - private case class CosmosBatchOperation(cosmosBatchBulkOperation: CosmosBatchBulkOperation, operationContext: OperationContext) - private case class TransactionalBulkItem(partitionKey: PartitionKey, objectNode: ObjectNode) + private case class CosmosBatchOperation( + cosmosBatchBulkOperation: CosmosBatchBulkOperation, 
+ operationContext: OperationContext, + originalItems: List[TransactionalBulkItem], + markerId: Option[String] = None) + private case class TransactionalBulkItem( + partitionKey: PartitionKey, + objectNode: ObjectNode, + itemId: String, + eTag: Option[String] = None) } private object TransactionalBulkWriter { @@ -870,6 +1282,17 @@ private object TransactionalBulkWriter { private val maxDelayOn408RequestTimeoutInMs = 3000 private val minDelayOn408RequestTimeoutInMs = 500 private val maxItemOperationsToShowInErrorMessage = 10 + // Cosmos DB server limit: maximum 100 operations per transactional batch + private val MaxOperationsPerBatch = 100 + // Default TTL for marker documents (orphan cleanup from crashed runs) + // 24 hours = 86400 seconds. Primary cleanup is active deletion, not TTL. + private val DefaultMarkerTtlSeconds = 86400 + // Bounded timeout for marker verification point-read. Prevents hanging on network stall, + // thread starvation, or extended SDK-internal retry loops. A timeout is treated as Inconclusive. + // 10 seconds is generous for a single point-read but covers SDK-internal 429 retries. 
+ private val MarkerVerificationTimeout = java.time.Duration.ofSeconds(10) + // Shared ObjectMapper for building marker documents — thread-safe, reused across all instances + private val markerObjectMapper = new ObjectMapper() private val TRANSACTIONAL_BULK_WRITER_REQUESTS_BOUNDED_ELASTIC_THREAD_NAME = "transactional-bulk-writer-requests-bounded-elastic" private val TRANSACTIONAL_BULK_WRITER_INPUT_BOUNDED_ELASTIC_THREAD_NAME = "transactional-bulk-writer-input-bounded-elastic" private val TRANSACTIONAL_BATCH_INPUT_BOUNDED_ELASTIC_THREAD_NAME = "transactional-batch-input-bounded-elastic" diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2ETransactionalBulkWriterITest.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2ETransactionalBulkWriterITest.scala new file mode 100644 index 000000000000..dab3a431e412 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/SparkE2ETransactionalBulkWriterITest.scala @@ -0,0 +1,470 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.{TestConfigurations, Utils}
import com.azure.cosmos.models.{PartitionKey, PartitionKeyBuilder}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SaveMode}

import java.util.UUID
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

//scalastyle:off multiple.string.literals
//scalastyle:off magic.number
class SparkE2ETransactionalBulkWriterITest extends IntegrationSpec
  with Spark
  with AutoCleanableCosmosContainersWithPkAsPartitionKey {

  // These tests require the Cosmos Emulator running locally
  // Run with: -Dspark-e2e_3-5_2-12=true

  private def getBaseWriteConfig(container: String): Map[String, String] = Map(
    "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST,
    "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY,
    "spark.cosmos.database" -> cosmosDatabase,
    "spark.cosmos.container" -> container,
    "spark.cosmos.write.bulk.enabled" -> "true",
    "spark.cosmos.write.bulk.transactional" -> "true"
  )

  // Base transactional-bulk config plus an explicit write strategy.
  private def writeConfigWithStrategy(container: String, strategy: String): Map[String, String] =
    getBaseWriteConfig(container) + ("spark.cosmos.write.strategy" -> strategy)

  private val simpleSchema = StructType(Seq(
    StructField("id", StringType, nullable = false),
    StructField("pk", StringType, nullable = false),
    StructField("name", StringType, nullable = false)
  ))

  // Creates a container with a 2-level hierarchical partition key (/tenantId, /userId),
  // MULTI_HASH kind, PK definition version V2. Caller deletes it when done.
  private def createHpkContainer(containerName: String) = {
    val containerProperties = new com.azure.cosmos.models.CosmosContainerProperties(
      containerName,
      new com.azure.cosmos.models.PartitionKeyDefinition()
    )
    val paths = new java.util.ArrayList[String]()
    paths.add("/tenantId")
    paths.add("/userId")
    containerProperties.getPartitionKeyDefinition.setPaths(paths)
    containerProperties.getPartitionKeyDefinition.setKind(com.azure.cosmos.models.PartitionKind.MULTI_HASH)
    containerProperties.getPartitionKeyDefinition.setVersion(com.azure.cosmos.models.PartitionKeyDefinitionVersion.V2)
    cosmosClient.getDatabase(cosmosDatabase).createContainer(containerProperties).block()
    cosmosClient.getDatabase(cosmosDatabase).getContainer(containerName)
  }

  // Transactional-bulk write config targeting an HPK container with ItemOverwrite.
  private def hpkWriteConfig(containerName: String): Map[String, String] = Map(
    "spark.cosmos.accountEndpoint" -> TestConfigurations.HOST,
    "spark.cosmos.accountKey" -> TestConfigurations.MASTER_KEY,
    "spark.cosmos.database" -> cosmosDatabase,
    "spark.cosmos.container" -> containerName,
    "spark.cosmos.write.bulk.enabled" -> "true",
    "spark.cosmos.write.bulk.transactional" -> "true",
    "spark.cosmos.write.strategy" -> "ItemOverwrite"
  )

  // Seeds a single { id, pk, name } document directly via the SDK.
  private def seedDocument(container: com.azure.cosmos.CosmosAsyncContainer,
                           id: String,
                           pkValue: String,
                           name: String): Unit = {
    val seedNode = Utils.getSimpleObjectMapper.createObjectNode()
    seedNode.put("id", id)
    seedNode.put("pk", pkValue)
    seedNode.put("name", name)
    container.createItem(seedNode, new PartitionKey(pkValue), null).block()
  }

  // Runs a transactional-bulk write of the given rows through the Spark connector.
  private def writeRows(rows: Seq[Row], schema: StructType, config: Map[String, String]): Unit = {
    spark.createDataFrame(rows.asJava, schema)
      .write
      .format("cosmos.oltp")
      .options(config)
      .mode(SaveMode.Append)
      .save()
  }

  // Counts documents for a given pk value.
  private def countByPk(container: com.azure.cosmos.CosmosAsyncContainer, pkValue: String): Int = {
    container
      .queryItems(s"SELECT * FROM c WHERE c.pk = '$pkValue'", classOf[ObjectNode])
      .collectList()
      .block()
      .size()
  }

  // =====================================================
  // Happy Path — Each Write Strategy
  // =====================================================

  "transactional write with ItemOverwrite" should "upsert documents atomically" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    val rows = Seq(
      Row(s"upsert-1-${UUID.randomUUID()}", pkValue, "Alice"),
      Row(s"upsert-2-${UUID.randomUUID()}", pkValue, "Bob"),
      Row(s"upsert-3-${UUID.randomUUID()}", pkValue, "Charlie")
    )
    writeRows(rows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemOverwrite"))

    // All 3 docs must exist afterwards
    countByPk(container, pkValue) shouldBe 3
  }

  "transactional write with ItemAppend" should "create new documents" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    val item1Id = s"append-1-${UUID.randomUUID()}"
    val item2Id = s"append-2-${UUID.randomUUID()}"

    val rows = Seq(
      Row(item1Id, pkValue, "Doc1"),
      Row(item2Id, pkValue, "Doc2")
    )
    writeRows(rows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemAppend"))

    // Both docs were created
    val item1 = container.readItem(item1Id, new PartitionKey(pkValue), classOf[ObjectNode]).block()
    item1 should not be null
    item1.getItem.get("name").asText() shouldEqual "Doc1"

    val item2 = container.readItem(item2Id, new PartitionKey(pkValue), classOf[ObjectNode]).block()
    item2 should not be null
    item2.getItem.get("name").asText() shouldEqual "Doc2"
  }

  "transactional write with ItemDelete" should "delete existing documents" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    // Seed 3 documents to delete
    val ids = (1 to 3).map { i =>
      val id = s"delete-$i-${UUID.randomUUID()}"
      seedDocument(container, id, pkValue, s"ToDelete-$i")
      id
    }

    // They must exist before the delete
    countByPk(container, pkValue) shouldBe 3

    // Delete DataFrame needs id and pk columns; name is a placeholder
    val deleteRows = ids.map(id => Row(id, pkValue, "placeholder"))
    writeRows(deleteRows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemDelete"))

    // All 3 docs were deleted
    countByPk(container, pkValue) shouldBe 0
  }

  "transactional write with ItemOverwriteIfNotModified" should "create new docs when no ETag present" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString
    val item1Id = s"conditional-${UUID.randomUUID()}"

    // No ETag → falls back to CREATE
    writeRows(
      Seq(Row(item1Id, pkValue, "ConditionalDoc")),
      simpleSchema,
      writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemOverwriteIfNotModified"))

    val item = container.readItem(item1Id, new PartitionKey(pkValue), classOf[ObjectNode]).block()
    item should not be null
    item.getItem.get("name").asText() shouldEqual "ConditionalDoc"
  }

  // =====================================================
  // Error / Atomicity Tests
  // =====================================================

  "transactional write with ItemAppend on existing docs" should "FAIL entire batch on first attempt" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    // Seed one document
    val existingId = s"existing-${UUID.randomUUID()}"
    seedDocument(container, existingId, pkValue, "AlreadyExists")

    // Batch contains the existing doc + a new doc (same PK)
    val newId = s"new-${UUID.randomUUID()}"
    val rows = Seq(
      Row(existingId, pkValue, "Duplicate"), // 409 — already exists
      Row(newId, pkValue, "NewDoc")
    )

    // existingId already exists -> 409 -> batch rolls back
    intercept[Exception] {
      writeRows(rows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemAppend"))
    }

    // New doc was NOT created (rolled back)
    val queryResult = container
      .queryItems(s"SELECT * FROM c WHERE c.id = '$newId' AND c.pk = '$pkValue'", classOf[ObjectNode])
      .collectList()
      .block()
    queryResult.size() shouldBe 0

    // Original doc unchanged
    val originalDoc = container.readItem(existingId, new PartitionKey(pkValue), classOf[ObjectNode]).block()
    originalDoc.getItem.get("name").asText() shouldEqual "AlreadyExists"
  }

  "transactional batch atomicity" should "roll back all operations on failure" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    // Seed doc B only
    val idB = s"docB-${UUID.randomUUID()}"
    seedDocument(container, idB, pkValue, "DocB")

    // Delete [A (missing), B (exists)] — A causes 404 -> entire batch rolls back
    val idA = s"docA-${UUID.randomUUID()}"
    val deleteRows = Seq(
      Row(idA, pkValue, "phantom"),
      Row(idB, pkValue, "DocB")
    )

    intercept[Exception] {
      writeRows(deleteRows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemDelete"))
    }

    // Doc B was NOT deleted (entire batch rolled back)
    val docB = container.readItem(idB, new PartitionKey(pkValue), classOf[ObjectNode]).block()
    docB should not be null
    docB.getItem.get("name").asText() shouldEqual "DocB"
  }

  "transactional write across multiple partition keys" should "group into separate atomic batches" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pk1 = UUID.randomUUID().toString
    val pk2 = UUID.randomUUID().toString

    // Items with 2 different PKs — should be grouped into 2 separate batches
    val rows = Seq(
      Row(s"pk1-1-${UUID.randomUUID()}", pk1, "A"),
      Row(s"pk1-2-${UUID.randomUUID()}", pk1, "B"),
      Row(s"pk2-1-${UUID.randomUUID()}", pk2, "C"),
      Row(s"pk2-2-${UUID.randomUUID()}", pk2, "D")
    )
    writeRows(rows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemOverwrite"))

    // All docs landed in their respective partitions
    countByPk(container, pk1) shouldBe 2
    countByPk(container, pk2) shouldBe 2
  }

  // =====================================================
  // HPK-Specific E2E Tests
  // =====================================================

  "transactional write with HPK ItemOverwrite" should "upsert documents with 2-level partition key" in {
    val containerName = s"test-hpk-upsert-${UUID.randomUUID()}"
    val container = createHpkContainer(containerName)

    try {
      val schema = StructType(Seq(
        StructField("id", StringType, nullable = false),
        StructField("tenantId", StringType, nullable = false),
        StructField("userId", StringType, nullable = false),
        StructField("score", IntegerType, nullable = false)
      ))

      val rows = Seq(
        Row(s"doc-1-${UUID.randomUUID()}", "Contoso", "alice", 100),
        Row(s"doc-2-${UUID.randomUUID()}", "Contoso", "alice", 200),
        Row(s"doc-3-${UUID.randomUUID()}", "Contoso", "alice", 300)
      )
      writeRows(rows, schema, hpkWriteConfig(containerName))

      // All 3 docs exist via query
      val queryResult = container
        .queryItems(s"SELECT * FROM c WHERE c.tenantId = 'Contoso' AND c.userId = 'alice'", classOf[ObjectNode])
        .collectList()
        .block()
      // 3 business docs (marker is actively deleted after success)
      queryResult.size() should be >= 3
    } finally {
      container.delete().block()
    }
  }

  "transactional write with HPK batch grouping" should "create separate batches per full HPK value" in {
    // Verifies that the String-based PK comparison correctly distinguishes
    // different HPK values that share a common prefix.
    val containerName = s"test-hpk-grouping-${UUID.randomUUID()}"
    val container = createHpkContainer(containerName)

    try {
      val schema = StructType(Seq(
        StructField("id", StringType, nullable = false),
        StructField("tenantId", StringType, nullable = false),
        StructField("userId", StringType, nullable = false),
        StructField("name", StringType, nullable = false)
      ))

      // Two different HPK values — must produce separate batches
      val rows = Seq(
        Row(s"alice-1-${UUID.randomUUID()}", "Contoso", "alice", "A1"),
        Row(s"alice-2-${UUID.randomUUID()}", "Contoso", "alice", "A2"),
        Row(s"bob-1-${UUID.randomUUID()}", "Contoso", "bob", "B1"),
        Row(s"bob-2-${UUID.randomUUID()}", "Contoso", "bob", "B2")
      )
      writeRows(rows, schema, hpkWriteConfig(containerName))

      // alice's partition
      val aliceCount = container
        .queryItems(s"SELECT * FROM c WHERE c.tenantId = 'Contoso' AND c.userId = 'alice'", classOf[ObjectNode])
        .collectList()
        .block()
      aliceCount.size() should be >= 2

      // bob's partition
      val bobCount = container
        .queryItems(s"SELECT * FROM c WHERE c.tenantId = 'Contoso' AND c.userId = 'bob'", classOf[ObjectNode])
        .collectList()
        .block()
      bobCount.size() should be >= 2
    } finally {
      container.delete().block()
    }
  }

  // =====================================================
  // Marker Cleanup Verification
  // =====================================================

  "transactional write marker cleanup" should "not leave marker documents after successful write" in {
    val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey)
    val pkValue = UUID.randomUUID().toString

    val rows = Seq(
      Row(s"marker-test-1-${UUID.randomUUID()}", pkValue, "Doc1"),
      Row(s"marker-test-2-${UUID.randomUUID()}", pkValue, "Doc2")
    )
    writeRows(rows, simpleSchema, writeConfigWithStrategy(cosmosContainersWithPkAsPartitionKey, "ItemOverwrite"))

    // Small delay to allow async marker deletion to complete
    Thread.sleep(2000)

    // Query all docs for this partition key
    val allDocs = container
      .queryItems(s"SELECT * FROM c WHERE c.pk = '$pkValue'", classOf[ObjectNode])
      .collectList()
      .block()

    // Only business docs should remain — the marker is actively deleted
    val (markerDocs, businessDocs) = allDocs.asScala
      .filter(_.has("id"))
      .partition(doc => doc.get("id").asText().startsWith("__tbw:"))

    markerDocs.size shouldBe 0 // marker was actively deleted after success
    businessDocs.size shouldBe 2
  }
}
//scalastyle:on magic.number
//scalastyle:on multiple.string.literals
.option("spark.cosmos.write.bulk.transactional", "true") - .option("spark.cosmos.write.bulk.enabled", "true") - .option("spark.cosmos.write.strategy", "ItemAppend") - .mode(SaveMode.Append) - .save() - } - val appendRootCause = getRootCause(appendException) - assert(appendRootCause.getMessage.contains("Transactional batches only support ItemOverwrite"), - s"Expected ItemAppend rejection, got: ${appendRootCause.getMessage}") + // ItemAppend should now be accepted for transactional batches + operationsDf.write + .format("cosmos.oltp") + .option("spark.cosmos.accountEndpoint", cosmosEndpoint) + .option("spark.cosmos.accountKey", cosmosMasterKey) + .option("spark.cosmos.database", cosmosDatabase) + .option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey) + .option("spark.cosmos.write.bulk.transactional", "true") + .option("spark.cosmos.write.bulk.enabled", "true") + .option("spark.cosmos.write.strategy", "ItemAppend") + .mode(SaveMode.Append) + .save() - // Test ItemDelete - should fail - val deleteException = intercept[Exception] { - operationsDf.write - .format("cosmos.oltp") - .option("spark.cosmos.accountEndpoint", cosmosEndpoint) - .option("spark.cosmos.accountKey", cosmosMasterKey) - .option("spark.cosmos.database", cosmosDatabase) - .option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey) - .option("spark.cosmos.write.bulk.transactional", "true") - .option("spark.cosmos.write.bulk.enabled", "true") - .option("spark.cosmos.write.strategy", "ItemDelete") - .mode(SaveMode.Append) - .save() - } - val deleteRootCause = getRootCause(deleteException) - assert(deleteRootCause.getMessage.contains("Transactional batches only support ItemOverwrite"), - s"Expected ItemDelete rejection, got: ${deleteRootCause.getMessage}") + // Verify item was created + val item1 = container.readItem(item1Id, new PartitionKey(partitionKeyValue), classOf[ObjectNode]).block() + item1 should not be null + item1.getItem.get("name").asText() shouldEqual "TestItem" + 
} - // Test ItemOverwriteIfNotModified - should fail - val replaceException = intercept[Exception] { - operationsDf.write - .format("cosmos.oltp") - .option("spark.cosmos.accountEndpoint", cosmosEndpoint) - .option("spark.cosmos.accountKey", cosmosMasterKey) - .option("spark.cosmos.database", cosmosDatabase) - .option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey) - .option("spark.cosmos.write.bulk.transactional", "true") - .option("spark.cosmos.write.bulk.enabled", "true") - .option("spark.cosmos.write.strategy", "ItemOverwriteIfNotModified") - .mode(SaveMode.Append) - .save() - } - val replaceRootCause = getRootCause(replaceException) - assert(replaceRootCause.getMessage.contains("Transactional batches only support ItemOverwrite"), - s"Expected ItemOverwriteIfNotModified rejection, got: ${replaceRootCause.getMessage}") + it should "accept ItemDelete write strategy" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + val partitionKeyValue = UUID.randomUUID().toString + val item1Id = s"test-delete-${UUID.randomUUID()}" + + // First, seed a document to delete + val seedNode = Utils.getSimpleObjectMapper.createObjectNode() + seedNode.put("id", item1Id) + seedNode.put("pk", partitionKeyValue) + seedNode.put("name", "ToBeDeleted") + container.createItem(seedNode, new PartitionKey(partitionKeyValue), null).block() + + // Verify it exists + val seedItem = container.readItem(item1Id, new PartitionKey(partitionKeyValue), classOf[ObjectNode]).block() + seedItem should not be null + + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("pk", StringType, nullable = false), + StructField("name", StringType, nullable = false) + )) + + val batchOperations = Seq( + Row(item1Id, partitionKeyValue, "ToBeDeleted") + ) + + val operationsDf = 
spark.createDataFrame(batchOperations.asJava, schema) + + // ItemDelete should now be accepted for transactional batches + operationsDf.write + .format("cosmos.oltp") + .option("spark.cosmos.accountEndpoint", cosmosEndpoint) + .option("spark.cosmos.accountKey", cosmosMasterKey) + .option("spark.cosmos.database", cosmosDatabase) + .option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey) + .option("spark.cosmos.write.bulk.transactional", "true") + .option("spark.cosmos.write.bulk.enabled", "true") + .option("spark.cosmos.write.strategy", "ItemDelete") + .mode(SaveMode.Append) + .save() + + // Verify item was deleted + val queryResult = container + .queryItems(s"SELECT * FROM c WHERE c.id = '$item1Id' AND c.pk = '$partitionKeyValue'", classOf[ObjectNode]) + .collectList() + .block() + queryResult.size() shouldBe 0 + } + + it should "accept ItemOverwriteIfNotModified write strategy" in { + val cosmosEndpoint = TestConfigurations.HOST + val cosmosMasterKey = TestConfigurations.MASTER_KEY + val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) + val partitionKeyValue = UUID.randomUUID().toString + val item1Id = s"test-replace-${UUID.randomUUID()}" + + val schema = StructType(Seq( + StructField("id", StringType, nullable = false), + StructField("pk", StringType, nullable = false), + StructField("name", StringType, nullable = false) + )) + + // ItemOverwriteIfNotModified without ETag falls back to CREATE + val batchOperations = Seq( + Row(item1Id, partitionKeyValue, "NewItem") + ) + + val operationsDf = spark.createDataFrame(batchOperations.asJava, schema) + + operationsDf.write + .format("cosmos.oltp") + .option("spark.cosmos.accountEndpoint", cosmosEndpoint) + .option("spark.cosmos.accountKey", cosmosMasterKey) + .option("spark.cosmos.database", cosmosDatabase) + .option("spark.cosmos.container", cosmosContainersWithPkAsPartitionKey) + .option("spark.cosmos.write.bulk.transactional", "true") + 
.option("spark.cosmos.write.bulk.enabled", "true") + .option("spark.cosmos.write.strategy", "ItemOverwriteIfNotModified") + .mode(SaveMode.Append) + .save() + + // Verify item was created (ItemOverwriteIfNotModified without ETag creates) + val item1 = container.readItem(item1Id, new PartitionKey(partitionKeyValue), classOf[ObjectNode]).block() + item1 should not be null + item1.getItem.get("name").asText() shouldEqual "NewItem" } it should "support simplified schema with default upsert operation" in { @@ -699,7 +753,7 @@ class TransactionalBatchITest extends IntegrationSpec val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) - + // Create operations for multiple partition keys intentionally in random order val pk1 = UUID.randomUUID().toString val pk2 = UUID.randomUUID().toString @@ -753,16 +807,16 @@ class TransactionalBatchITest extends IntegrationSpec val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) - + // Use a small maxPendingOperations value to force batch-level limiting // With maxPendingOperations=50, maxPendingBatches = 50/50 = 1 // This means only 1 batch should be in-flight at a time val maxPendingOperations = 50 - + // Create 200 operations across 4 batches (50 operations per partition key = 1 batch each) // This will test that the semaphore properly limits concurrent batches val partitionKeys = (1 to 4).map(_ => UUID.randomUUID().toString) - + val schema = StructType(Seq( StructField("id", StringType, nullable = false), StructField("pk", StringType, nullable = false), @@ -797,11 +851,11 @@ class TransactionalBatchITest extends IntegrationSpec .queryItems(s"SELECT VALUE COUNT(1) FROM c WHERE c.pk = '$pk'", classOf[Long]) .collectList() .block() - + 
val count = if (queryResult.isEmpty) 0L else queryResult.get(0) assert(count == 50, s"Expected 50 items for partition key $pk, but found $count") } - + // If we get here without deadlock or timeout, batch-level backpressure is working // The test verifies: // 1. Operations complete successfully even with tight batch limit @@ -813,22 +867,22 @@ class TransactionalBatchITest extends IntegrationSpec val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) - + // Create two sets of operations for the same partition key // This tests that multiple successful batches can be written to the same partition key val partitionKeyValue = UUID.randomUUID().toString - + val schema = StructType(Seq( StructField("id", StringType, nullable = false), StructField("pk", StringType, nullable = false), StructField("counter", IntegerType, nullable = false) )) - + // First, write initial items that will later be updated transactionally val initialItems = (1 to 10).map { i => Row(s"item-$i", partitionKeyValue, 0) } - + val initialDf = spark.createDataFrame(initialItems.asJava, schema) initialDf.write .format("cosmos.oltp") @@ -840,27 +894,27 @@ class TransactionalBatchITest extends IntegrationSpec .option("spark.cosmos.write.bulk.enabled", "true") .mode(SaveMode.Append) .save() - + // Now update items 1-5 to counter=1, then items 6-10 to counter=2 // Both are atomic batches for the same partition key - // If retries from the first batch interleave with the second batch, + // If retries from the first batch interleave with the second batch, // atomicity would be violated val batch1 = (1 to 5).map { i => Row(s"item-$i", partitionKeyValue, 1) } - + val batch2 = (6 to 10).map { i => Row(s"item-$i", partitionKeyValue, 2) } - + val allUpdates = batch1 ++ batch2 val updatesDf = spark.createDataFrame(allUpdates.asJava, schema) - + // Delete existing 
items first since Overwrite mode is not supported in transactional mode (1 to 10).foreach { i => container.deleteItem(s"item-$i", new PartitionKey(partitionKeyValue), null).block() } - + updatesDf.write .format("cosmos.oltp") .option("spark.cosmos.accountEndpoint", cosmosEndpoint) @@ -872,7 +926,7 @@ class TransactionalBatchITest extends IntegrationSpec .option("spark.cosmos.write.bulk.maxPendingOperations", "10") // Force separate batches .mode(SaveMode.Append) .save() - + // Verify the final state: all items should have their expected counter values // If interleaving occurred, some updates might have been lost or inconsistent (1 to 5).foreach { i => @@ -881,37 +935,37 @@ class TransactionalBatchITest extends IntegrationSpec new PartitionKey(partitionKeyValue), classOf[ObjectNode] ).block() - + assert(item != null, s"Item item-$i should exist") - assert(item.getItem.get("counter").asInt() == 1, + assert(item.getItem.get("counter").asInt() == 1, s"Item item-$i should have counter=1, but got ${item.getItem.get("counter").asInt()}") } - + (6 to 10).foreach { i => val item = container.readItem( s"item-$i", new PartitionKey(partitionKeyValue), classOf[ObjectNode] ).block() - + assert(item != null, s"Item item-$i should exist") assert(item.getItem.get("counter").asInt() == 2, s"Item item-$i should have counter=2, but got ${item.getItem.get("counter").asInt()}") } } - + it should "handle batch-level retries for retriable errors without interleaving" in { val cosmosEndpoint = TestConfigurations.HOST val cosmosMasterKey = TestConfigurations.MASTER_KEY - + val partitionKeyValue = UUID.randomUUID().toString - + val schema = StructType(Seq( StructField("id", StringType, nullable = false), StructField("pk", StringType, nullable = false), StructField("counter", IntegerType, nullable = false) )) - + // Configuration for Spark connector - must match exactly for cache lookup val cfg = Map( "spark.cosmos.accountEndpoint" -> cosmosEndpoint, @@ -919,13 +973,13 @@ class 
TransactionalBatchITest extends IntegrationSpec "spark.cosmos.database" -> cosmosDatabase, "spark.cosmos.container" -> cosmosContainersWithPkAsPartitionKey ) - + // Create initial items with counter = 0 // This FIRST write ensures Spark creates the client and caches it val initialItems = (1 to 10).map { i => Row(s"item-$i", partitionKeyValue, 0) } - + val initialDf = spark.createDataFrame(initialItems.asJava, schema) initialDf.write .format("cosmos.oltp") @@ -934,15 +988,15 @@ class TransactionalBatchITest extends IntegrationSpec .option("spark.cosmos.write.bulk.enabled", "true") .mode(SaveMode.Append) .save() - + // NOW get the actual client that Spark created and cached val clientFromCache = udf.CosmosAsyncClientCache .getCosmosClientFromCache(cfg) .getClient .asInstanceOf[CosmosAsyncClient] - + val container = clientFromCache.getDatabase(cosmosDatabase).getContainer(cosmosContainersWithPkAsPartitionKey) - + // Configure fault injection to inject retriable 429 TOO_MANY_REQUEST errors on BATCH_ITEM operations // 429 errors are retriable and should trigger batch-level retry without aborting the job // This tests that transactional batches handle retries at the batch level @@ -960,18 +1014,18 @@ class TransactionalBatchITest extends IntegrationSpec ) .duration(Duration.ofMinutes(5)) .build() - + // Configure the fault injection rule on the container CosmosFaultInjectionHelper.configureFaultInjectionRules(container, java.util.Collections.singletonList(faultInjectionRule)).block() - + try { // Now write just ONE more item using NEW ID to avoid conflicts with the initial write // This is the absolute simplest test case to verify batch-level retry handling // With maxPendingOperations=1, this single item becomes its own batch val singleItem = Seq(Row("item-11", partitionKeyValue, 1)) - + val updatesDf = spark.createDataFrame(singleItem.asJava, schema) - + updatesDf.write .format("cosmos.oltp") .options(cfg) @@ -980,16 +1034,16 @@ class TransactionalBatchITest extends 
IntegrationSpec .option("spark.cosmos.write.bulk.maxPendingOperations", "1") // Ensure minimal batch size .mode(SaveMode.Append) .save() - + // Verify that fault injection triggered and was retried successfully // This confirms that retriable errors (429) trigger batch-level retries val hitCount = faultInjectionRule.getHitCount assert(hitCount > 0, s"Fault injection should have tracked BATCH_ITEM operations, but hit count was $hitCount") - + // Verify the single item was written correctly despite the injected 429 error and retry val item11 = container.readItem("item-11", new PartitionKey(partitionKeyValue), classOf[ObjectNode]).block() assert(item11 != null, "Item item-11 should exist") - assert(item11.getItem.get("counter").asInt() == 1, + assert(item11.getItem.get("counter").asInt() == 1, s"Item item-11 should have counter=1, but got ${item11.getItem.get("counter").asInt()}") } finally { // Clean up: disable the fault injection rule diff --git a/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransactionalBulkWriterSpec.scala b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransactionalBulkWriterSpec.scala new file mode 100644 index 000000000000..a89e0d271709 --- /dev/null +++ b/sdk/cosmos/azure-cosmos-spark_3/src/test/scala/com/azure/cosmos/spark/TransactionalBulkWriterSpec.scala @@ -0,0 +1,895 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+package com.azure.cosmos.spark + +import com.azure.cosmos.models.{ + CosmosBatch, + CosmosBatchItemRequestOptions, + CosmosBatchResponse, + CosmosBulkOperations, + ModelBridgeInternal, + PartitionKey, + PartitionKeyBuilder, + PartitionKeyDefinition +} +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.ObjectNode + +import java.time.Duration +import java.util +import java.util.concurrent.ConcurrentHashMap +// scalastyle:off underscore.import +import scala.collection.JavaConverters._ +// scalastyle:on underscore.import + +//scalastyle:off multiple.string.literals +//scalastyle:off magic.number +//scalastyle:off null +class TransactionalBulkWriterSpec extends UnitSpec { + + private val objectMapper = new ObjectMapper() + + private def createObjectNode(id: String, pk: String, eTag: Option[String] = None): ObjectNode = { + val node = objectMapper.createObjectNode() + node.put("id", id) + node.put("pk", pk) + eTag.foreach(e => node.put("_etag", e)) + node + } + + private def createMockBatchResponse( + statusCode: Int, + subStatusCode: Int, + operationResults: List[(Int, Int)] // (statusCode, subStatusCode) per operation + ): CosmosBatchResponse = { + val response = ModelBridgeInternal.createCosmosBatchResponse( + statusCode, + subStatusCode, + null, // errorMessage + new util.HashMap[String, String](), + null // cosmosDiagnostics + ) + + val pk = new PartitionKey("test-pk") + val results = operationResults.map { case (opStatusCode, opSubStatusCode) => + val dummyOperation = CosmosBulkOperations.getUpsertItemOperation( + createObjectNode("dummy", "test-pk"), pk) + ModelBridgeInternal.createCosmosBatchResult( + null, // eTag + 1.0, // requestCharge + null, // resourceObject + opStatusCode, + Duration.ZERO, + opSubStatusCode, + dummyOperation + ) + }.asJava + + ModelBridgeInternal.addCosmosBatchResultInResponse(response, results) + response + } + + // ===================================================== + // Recovery for Delete 
Operations + // ===================================================== + + "recovery path" should "handle delete operations without NPE (Issue 3 fix)" in { + // Delete operations have no item body — getItem returns null + // The fix uses originalItems (TransactionalBulkItem) instead of batch.getOperations + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + batch.deleteItemOperation("doc1") + + val operations = batch.getOperations + operations.size() should be(1) + + // Verify getItem returns null for delete — this is the root cause of Issue 3 + val item = operations.get(0).getItem[ObjectNode] + item should be(null) + + // Verify that wrapping as upsert (the fix) preserves the objectNode + val objectNode = createObjectNode("doc1", "user-A") + val wrappedOp = CosmosBulkOperations.getUpsertItemOperation(objectNode, pk) + wrappedOp.getItem[ObjectNode] should not be null + wrappedOp.getItem[ObjectNode].get("id").asText() should be("doc1") + wrappedOp.getPartitionKeyValue should be(pk) + } + + // ===================================================== + // TransactionalBulkItem Field Extraction Tests + // ===================================================== + + "getId pattern" should "extract id from ObjectNode" in { + val objectNode = createObjectNode("doc-123", "user-A") + + val idField = objectNode.get(CosmosConstants.Properties.Id) + idField should not be null + idField.isTextual should be(true) + idField.textValue() should be("doc-123") + } + + "getETag pattern" should "extract eTag from ObjectNode when present" in { + val objectNode = createObjectNode("doc-456", "user-B", Some("etag-abc")) + + val eTagField = objectNode.get(CosmosConstants.Properties.ETag) + eTagField should not be null + eTagField.isTextual should be(true) + eTagField.textValue() should be("etag-abc") + } + + it should "return null when eTag is missing" in { + val objectNode = createObjectNode("doc-789", "user-C") + + val eTagField = 
objectNode.get(CosmosConstants.Properties.ETag) + eTagField should be(null) + } + + // ===================================================== + // CosmosBatch Strategy Mapping Tests + // ===================================================== + + "CosmosBatch strategy mapping" should "map ItemOverwrite to upsertItemOperation" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + val objectNode = createObjectNode("doc1", "user-A") + + batch.upsertItemOperation(objectNode) + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("UPSERT") + } + + it should "map ItemAppend to createItemOperation" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + val objectNode = createObjectNode("doc1", "user-A") + + batch.createItemOperation(objectNode) + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("CREATE") + } + + it should "map ItemDelete to deleteItemOperation with itemId only" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + + batch.deleteItemOperation("doc1") + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("DELETE") + batch.getOperations.get(0).getId should be("doc1") + batch.getOperations.get(0).getItem[ObjectNode] should be(null) + } + + it should "map ItemDeleteIfNotModified to deleteItemOperation with ETag" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + val requestOptions = new CosmosBatchItemRequestOptions() + requestOptions.setIfMatchETag("etag-123") + + batch.deleteItemOperation("doc1", requestOptions) + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("DELETE") + batch.getOperations.get(0).getId should be("doc1") + } + + it should "map ItemOverwriteIfNotModified with ETag to 
replaceItemOperation" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + val objectNode = createObjectNode("doc1", "user-A", Some("etag-abc")) + val requestOptions = new CosmosBatchItemRequestOptions() + requestOptions.setIfMatchETag("etag-abc") + + batch.replaceItemOperation("doc1", objectNode, requestOptions) + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("REPLACE") + batch.getOperations.get(0).getId should be("doc1") + } + + it should "map ItemOverwriteIfNotModified without ETag to createItemOperation" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + val objectNode = createObjectNode("doc1", "user-A") // no ETag + + batch.createItemOperation(objectNode) + + batch.getOperations.size() should be(1) + batch.getOperations.get(0).getOperationType.toString should be("CREATE") + } + + it should "preserve operation order in batch" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + + batch.createItemOperation(createObjectNode("doc1", "user-A")) + batch.upsertItemOperation(createObjectNode("doc2", "user-A")) + batch.deleteItemOperation("doc3") + + val ops = batch.getOperations + ops.size() should be(3) + ops.get(0).getOperationType.toString should be("CREATE") + ops.get(1).getOperationType.toString should be("UPSERT") + ops.get(2).getOperationType.toString should be("DELETE") + } + + // ===================================================== + // shouldIgnore Status Code Tests + // (Tests the Exceptions helper methods used by shouldIgnore) + // ===================================================== + + "shouldIgnore for ItemAppend" should "ignore 409 Conflict (item already exists)" in { + Exceptions.isResourceExistsException(409) should be(true) + } + + it should "not ignore other status codes" in { + Exceptions.isResourceExistsException(404) should be(false) + 
Exceptions.isResourceExistsException(412) should be(false) + Exceptions.isResourceExistsException(200) should be(false) + } + + "shouldIgnore for ItemDelete" should "ignore 404/0 Not Found" in { + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) + } + + it should "NOT ignore 404/1002 (partition key range gone — transient, not semantic)" in { + // 404/1002 is a transient error (partition moved), must flow through retry, NOT shouldIgnore + Exceptions.isNotFoundExceptionCore(404, 1002) should be(false) + } + + it should "not ignore 409 or 412" in { + Exceptions.isNotFoundExceptionCore(409, 0) should be(false) + Exceptions.isNotFoundExceptionCore(412, 0) should be(false) + } + + "shouldIgnore for ItemDeleteIfNotModified" should "ignore 404/0 but NOT 412" in { + // TransactionalBulkWriter excludes 412 — ambiguous on retry for batch operations + // BulkWriter includes both 404/0 and 412 + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) + // 412 is a valid Precondition Failed code, but should NOT be in TransactionalBulkWriter's shouldIgnore + Exceptions.isPreconditionFailedException(412) should be(true) // helper returns true... + // ...but TransactionalBulkWriter's shouldIgnore does NOT include 412 for this strategy + } + + "shouldIgnore for ItemOverwriteIfNotModified" should "ignore 409 and 404/0 but NOT 412" in { + Exceptions.isResourceExistsException(409) should be(true) + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) + // 412 is excluded from TransactionalBulkWriter's shouldIgnore + Exceptions.isPreconditionFailedException(412) should be(true) // helper returns true... 
+ // ...but TransactionalBulkWriter's shouldIgnore does NOT include 412 for this strategy + } + + "shouldIgnore for ItemPatchIfExists" should "ignore 404/0 Not Found" in { + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) + } + + "shouldIgnore for ItemOverwrite" should "have no ignorable errors" in { + // Upsert always succeeds — there are no semantic errors to ignore + Exceptions.isResourceExistsException(200) should be(false) + Exceptions.isNotFoundExceptionCore(200, 0) should be(false) + Exceptions.isPreconditionFailedException(200) should be(false) + } + + // ===================================================== + // Transient Error Identification Tests + // ===================================================== + + "canBeTransientFailure" should "identify transient status codes" in { + Exceptions.canBeTransientFailure(408, 0) should be(true) // Request Timeout + Exceptions.canBeTransientFailure(410, 0) should be(true) // Gone + Exceptions.canBeTransientFailure(500, 0) should be(true) // Internal Server Error + Exceptions.canBeTransientFailure(503, 0) should be(true) // Service Unavailable + Exceptions.canBeTransientFailure(404, 1002) should be(true) // Partition Key Range Gone + } + + it should "NOT identify semantic errors as transient" in { + Exceptions.canBeTransientFailure(404, 0) should be(false) // Not Found (semantic) + Exceptions.canBeTransientFailure(409, 0) should be(false) // Conflict (semantic) + Exceptions.canBeTransientFailure(412, 0) should be(false) // Precondition Failed + Exceptions.canBeTransientFailure(400, 0) should be(false) // Bad Request + Exceptions.canBeTransientFailure(429, 0) should be(false) // Too Many Requests (SDK handles) + } + + // ===================================================== + // shouldRetry Strategy-Specific Tests + // ===================================================== + + "shouldRetry for ItemOverwrite" should "retry on 404/0 (TTL expiration race)" in { + // BulkWriter and TransactionalBulkWriter both 
retry upsert on 404/0 + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) + } + + it should "retry on transient failures" in { + Exceptions.canBeTransientFailure(408, 0) should be(true) + Exceptions.canBeTransientFailure(503, 0) should be(true) + } + + "shouldRetry for other strategies" should "NOT retry on 404/0 (semantic error)" in { + // For non-ItemOverwrite strategies, 404/0 is NOT retried — it's a semantic error + // (except for shouldIgnore which is checked separately before shouldRetry) + Exceptions.isNotFoundExceptionCore(404, 0) should be(true) // helper returns true... + // ...but shouldRetry only includes isNotFoundExceptionCore for ItemOverwrite + } + + // ===================================================== + // CosmosBatchResponse / CosmosBatchOperationResult Tests + // (Verifies the infrastructure used by shouldIgnoreOnRetry) + // ===================================================== + + "CosmosBatchResponse" should "be constructable with per-operation results" in { + val response = createMockBatchResponse( + statusCode = 409, + subStatusCode = 0, + operationResults = List((409, 0), (424, 0), (424, 0)) + ) + + response.getStatusCode should be(409) + response.getResults.size() should be(3) + response.getResults.get(0).getStatusCode should be(409) + response.getResults.get(1).getStatusCode should be(424) + response.getResults.get(2).getStatusCode should be(424) + } + + "shouldIgnoreOnRetry first-operation check" should "find first non-424 result at index 0" in { + // Scenario: ItemAppend retry, op[0]=409, op[1]=424, op[2]=424 + // The first non-424 is at index 0 -> shouldIgnoreOnRetry should return true + val response = createMockBatchResponse(409, 0, List((409, 0), (424, 0), (424, 0))) + val results = response.getResults.asScala + + val firstNon424 = results.zipWithIndex.find { case (result, _) => + result.getStatusCode != 424 + } + + firstNon424 should be(defined) + firstNon424.get._2 should be(0) // index 0 + 
firstNon424.get._1.getStatusCode should be(409) + // For ItemAppend, 409 is ignorable + Exceptions.isResourceExistsException(409) should be(true) + } + + it should "reject when first non-424 result is NOT at index 0" in { + // Scenario: op[0]=424, op[1]=404, op[2]=424 + // The first non-424 is at index 1 -> shouldIgnoreOnRetry should return false + val response = createMockBatchResponse(404, 0, List((424, 0), (404, 0), (424, 0))) + val results = response.getResults.asScala + + val firstNon424 = results.zipWithIndex.find { case (result, _) => + result.getStatusCode != 424 + } + + firstNon424 should be(defined) + firstNon424.get._2 should be(1) // index 1 -> NOT first operation → reject + firstNon424.get._1.getStatusCode should be(404) + } + + it should "return None when all results are 424" in { + val response = createMockBatchResponse(424, 0, List((424, 0), (424, 0))) + val results = response.getResults.asScala + + val firstNon424 = results.zipWithIndex.find { case (result, _) => + result.getStatusCode != 424 + } + + firstNon424 should be(empty) + } + + "shouldIgnoreOnRetry attempt guard" should "distinguish first attempt from retry" in { + // attemptNumber = 1 -> first attempt, shouldIgnoreOnRetry must return false + // attemptNumber > 1 -> retry, shouldIgnoreOnRetry may return true + // This test verifies the guard logic pattern + val attemptNumberFirstAttempt = 1 + val attemptNumberRetry = 2 + + (attemptNumberFirstAttempt <= 1) should be(true) // blocked + (attemptNumberRetry <= 1) should be(false) // allowed + } + + // ===================================================== + // originalItems Wrapping for Recovery Tests + // ===================================================== + + "originalItems recovery wrapping" should "preserve objectNode via upsert wrapper" in { + // When recovery extracts items from batches, it wraps TransactionalBulkItem + // as CosmosBulkOperations.getUpsertItemOperation to preserve the objectNode. 
+ // This verifies getItem[ObjectNode] works on the wrapper. + val pk = new PartitionKey("user-A") + val objectNode = createObjectNode("doc1", "user-A") + + val wrapped = CosmosBulkOperations.getUpsertItemOperation(objectNode, pk) + + wrapped.getPartitionKeyValue should be(pk) + wrapped.getItem[ObjectNode] should not be null + wrapped.getItem[ObjectNode].get("id").asText() should be("doc1") + wrapped.getItem[ObjectNode].get("pk").asText() should be("user-A") + } + + it should "work for items that were originally deletes" in { + // Delete operations have null item bodies, but the originalItems + // preserve the original objectNode. The upsert wrapper preserves it. + val pk = new PartitionKey("user-B") + val objectNode = createObjectNode("doc-to-delete", "user-B") + + // The original delete has no body + val deleteBatch = CosmosBatch.createCosmosBatch(pk) + deleteBatch.deleteItemOperation("doc-to-delete") + deleteBatch.getOperations.get(0).getItem[ObjectNode] should be(null) // NPE source + + // But the recovery wrapping preserves the original objectNode + val wrapped = CosmosBulkOperations.getUpsertItemOperation(objectNode, pk) + wrapped.getItem[ObjectNode] should not be null + wrapped.getItem[ObjectNode].get("id").asText() should be("doc-to-delete") + } + + // ===================================================== + // isIdempotent Guard Tests + // ===================================================== + + "isIdempotent re-enqueue guard" should "have correct default value" in { + // isIdempotent defaults to true — safe for ItemOverwrite (upsert) + // Non-idempotent strategies (increment patch) will set this to false + val defaultIsIdempotent = true + defaultIsIdempotent should be(true) + } + + it should "block re-enqueue for non-idempotent operations" in { + // Pattern: if (!isIdempotent) skip re-enqueue + val isIdempotent = false + (!isIdempotent) should be(true) // would skip + } + + it should "allow re-enqueue for idempotent operations" in { + val isIdempotent = 
true + (!isIdempotent) should be(false) // would NOT skip + } + + "batchIsIdempotent computation" should "be true for non-patch strategies" in { + // All non-patch strategies are idempotent: upsert, create, delete, replace + // produce the same result on retry. + val nonPatchStrategies = List( + ItemWriteStrategy.ItemOverwrite, + ItemWriteStrategy.ItemAppend, + ItemWriteStrategy.ItemDelete, + ItemWriteStrategy.ItemDeleteIfNotModified, + ItemWriteStrategy.ItemOverwriteIfNotModified + ) + for (strategy <- nonPatchStrategies) { + val isIdempotent = strategy match { + case ItemWriteStrategy.ItemPatch | ItemWriteStrategy.ItemPatchIfExists => false // placeholder + case _ => true + } + isIdempotent should be(true) + } + } + + it should "be true for ItemPatch with only Set/Add/Replace/Remove operations" in { + // Patch operations like set, add, replace, remove are idempotent — + // applying them twice produces the same document state. + val hasIncrement = List( + CosmosPatchOperationTypes.Set, + CosmosPatchOperationTypes.Add, + CosmosPatchOperationTypes.Replace, + CosmosPatchOperationTypes.Remove + ).exists(_ == CosmosPatchOperationTypes.Increment) + hasIncrement should be(false) // no increment → idempotent + } + + it should "be false for ItemPatch with Increment operations" in { + // Increment is non-idempotent — double-applying corrupts counters. + // The batchIsIdempotent flag must be false when any column config uses Increment. 
+ val operationTypes = List( + CosmosPatchOperationTypes.Set, + CosmosPatchOperationTypes.Increment, // non-idempotent + CosmosPatchOperationTypes.Replace + ) + val hasIncrement = operationTypes.exists(_ == CosmosPatchOperationTypes.Increment) + hasIncrement should be(true) // has increment → NOT idempotent → batchIsIdempotent = false + } + + // ===================================================== + // Duplicate PK Detection with String Keys + // (Verifies the hashCode fix — PartitionKey.toString() as set key) + // ===================================================== + + "PartitionKey String-based keying" should "work regardless of PartitionKey.hashCode() behavior" in { + // PartitionKey.hashCode() may use Object.hashCode() (identity-based), which violates + // the Java equals/hashCode contract. Our fix uses PartitionKey.toString() as the set key + // instead of PartitionKey directly. This test verifies that the String-based approach + // produces correct results whether or not the SDK fixes hashCode() in the future. + val pk1 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-1").build() + val pk2 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-1").build() + + // Value equality holds + pk1.equals(pk2) should be(true) + + // String-based keying always detects duplicates, regardless of hashCode() behavior + pk1.toString should be(pk2.toString) + val set = ConcurrentHashMap.newKeySet[String]() + set.add(pk1.toString) should be(true) // first add -> true + set.add(pk2.toString) should be(false) // duplicate detected via String equality -> false + } + + "duplicate PK detection with String keys (C10 fix)" should "detect value-equal HPK partition keys" in { + // we use PartitionKey.toString() as the set key. 
+ // toString() returns deterministic JSON: '["tenant-A","user-1","session-1"]' + val pk1 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-1").build() + val pk2 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-1").build() + + pk1.toString should be(pk2.toString) + + val set = ConcurrentHashMap.newKeySet[String]() + set.add(pk1.toString) should be(true) // first add -> true + set.add(pk2.toString) should be(false) // duplicate detected -> false + } + + it should "correctly distinguish different HPK values" in { + val pk1 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-1").build() + val pk2 = new PartitionKeyBuilder() + .add("tenant-A").add("user-1").add("session-2").build() // different session + + pk1.toString should not be pk2.toString + + val set = ConcurrentHashMap.newKeySet[String]() + set.add(pk1.toString) should be(true) + set.add(pk2.toString) should be(true) // different PK -> no conflict + } + + it should "work for single partition keys too" in { + val pk1 = new PartitionKey("Seattle") + val pk2 = new PartitionKey("Seattle") + + pk1.toString should be(pk2.toString) + + val set = ConcurrentHashMap.newKeySet[String]() + set.add(pk1.toString) should be(true) + set.add(pk2.toString) should be(false) // duplicate detected + } + + // ===================================================== + // isAllowedProperty HPK False Positive Fix + // ===================================================== + + "isAllowedProperty " should "allow patching /user when PK paths are /tenantId/userId/sessionId" in { + // : List("/tenantId", "/userId", "/sessionId").contains("/user") + // -> false -> ALLOWED (CORRECT) + val pkDef = new PartitionKeyDefinition() + val paths = new java.util.ArrayList[String]() + paths.add("/tenantId") + paths.add("/userId") + paths.add("/sessionId") + pkDef.setPaths(paths) + + // The fix uses Java List.contains() — exact match, not substring + pkDef.getPaths.contains("/user") 
should be(false) // not a PK path + pkDef.getPaths.contains("/tenant") should be(false) // not a PK path + pkDef.getPaths.contains("/session") should be(false) // not a PK path + pkDef.getPaths.contains("/tenantId") should be(true) // IS a PK path + pkDef.getPaths.contains("/userId") should be(true) // IS a PK path + pkDef.getPaths.contains("/sessionId") should be(true) // IS a PK path + } + + it should "work for single partition key definitions" in { + val pkDef = new PartitionKeyDefinition() + val paths = new java.util.ArrayList[String]() + paths.add("/pk") + pkDef.setPaths(paths) + + pkDef.getPaths.contains("/pk") should be(true) + pkDef.getPaths.contains("/p") should be(false) // substring — should NOT match + pkDef.getPaths.contains("/pkId") should be(false) // superstring — should NOT match + } + + it should "still block actual PK paths from patching" in { + val pkDef = new PartitionKeyDefinition() + val paths = new java.util.ArrayList[String]() + paths.add("/tenantId") + paths.add("/userId") + paths.add("/sessionId") + pkDef.setPaths(paths) + + // These ARE PK paths — getPaths.contains returns true → blocked + pkDef.getPaths.contains("/tenantId") should be(true) + pkDef.getPaths.contains("/userId") should be(true) + pkDef.getPaths.contains("/sessionId") should be(true) + } + + it should "block system properties regardless of PK definition" in { + // System properties (_rid, _self, _etag, _attachments, _ts) and id + // are always immutable and must be blocked from patching. + // This is tested via the CosmosPatchHelper constants, not getPaths. 
+ val systemProps = Set("_rid", "_self", "_etag", "_attachments", "_ts") + systemProps.contains("_rid") should be(true) + systemProps.contains("_etag") should be(true) + systemProps.contains("_ts") should be(true) + // "id" is also immutable + "id" should be("id") + } + + it should "handle edge case where field name is a suffix of a PK path" in { + // e.g., field "/nantId" is a suffix of "/tenantId" + // New code: List("/tenantId", "/userId", "/sessionId").contains("/nantId") → false → ALLOWED + val pkDef = new PartitionKeyDefinition() + val paths = new java.util.ArrayList[String]() + paths.add("/tenantId") + paths.add("/userId") + paths.add("/sessionId") + pkDef.setPaths(paths) + + pkDef.getPaths.contains("/nantId") should be(false) // suffix of /tenantId + pkDef.getPaths.contains("/erId") should be(false) // suffix of /userId + pkDef.getPaths.contains("/ionId") should be(false) // suffix of /sessionId + } + + // ===================================================== + // Batch Marker Document Tests + // ===================================================== + + "buildMarkerDocument pattern" should "create a minimal marker with id, ttl, and PK fields" in { + val om = new ObjectMapper() + val businessItem = om.createObjectNode() + businessItem.put("id", "doc-1") + businessItem.put("tenantId", "Contoso") + businessItem.put("userId", "alice") + businessItem.put("sessionId", "sess-99") + businessItem.put("score", 42) + + // Simulate buildMarkerDocument logic + val markerId = "__tbw:12345:3:1" + val markerTtlSeconds = 86400 + val partitionKeyPaths = List("/tenantId", "/userId", "/sessionId") + + val markerNode = om.createObjectNode() + markerNode.put("id", markerId) + markerNode.put("ttl", markerTtlSeconds) + partitionKeyPaths.foreach(path => { + val fieldName = path.stripPrefix("/") + val value = businessItem.get(fieldName) + if (value != null) { + markerNode.set(fieldName, value.deepCopy()) + } + }) + + // Verify marker has id + ttl + PK fields only (no business fields 
like "score") + markerNode.get("id").asText() should be("__tbw:12345:3:1") + markerNode.get("ttl").asInt() should be(86400) + markerNode.get("tenantId").asText() should be("Contoso") + markerNode.get("userId").asText() should be("alice") + markerNode.get("sessionId").asText() should be("sess-99") + markerNode.has("score") should be(false) // business field NOT in marker + } + + "marker ID" should "be deterministic for the same jobRunId, sparkPartitionId, and batchSeq" in { + val jobRunId = "task-attempt-12345" + val sparkPartitionId = 3 + val batchSeq = 17L + + val id1 = s"__tbw:$jobRunId:$sparkPartitionId:$batchSeq" + val id2 = s"__tbw:$jobRunId:$sparkPartitionId:$batchSeq" + + id1 should be(id2) + id1 should be("__tbw:task-attempt-12345:3:17") + } + + it should "be different for different batchSeq values" in { + val id1 = s"__tbw:job1:0:1" + val id2 = s"__tbw:job1:0:2" + + id1 should not be id2 + } + + it should "be different for different jobRunIds" in { + val id1 = s"__tbw:job-alpha:0:1" + val id2 = s"__tbw:job-beta:0:1" + + id1 should not be id2 + } + + // ===================================================== + // 100-Item Boundary (Marker Skip) + // ===================================================== + + "C15 marker skip" should "add marker when batch has fewer than 100 items" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + + // Add 99 business items + for (i <- 1 to 99) { + batch.upsertItemOperation(createObjectNode(s"doc-$i", "user-A")) + } + batch.getOperations.size() should be(99) + + // Adding marker makes it 100 — within server limit + batch.upsertItemOperation(createObjectNode("__tbw:test:0:1", "user-A")) + batch.getOperations.size() should be(100) // exactly at limit — OK + } + + it should "skip marker when batch already has 100 items" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + + // Add 100 business items + for (i <- 1 to 100) { + 
batch.upsertItemOperation(createObjectNode(s"doc-$i", "user-A")) + } + batch.getOperations.size() should be(100) + + // bulkItemsList.size() < 100 -> false -> skip marker + val shouldAddMarker = batch.getOperations.size() < 100 + shouldAddMarker should be(false) + } + + "marker position" should "always be the last operation in the batch" in { + val pk = new PartitionKey("user-A") + val batch = CosmosBatch.createCosmosBatch(pk) + + // Add business items first + batch.createItemOperation(createObjectNode("doc-1", "user-A")) + batch.upsertItemOperation(createObjectNode("doc-2", "user-A")) + batch.deleteItemOperation("doc-3") + + // Add marker last (same as production code: upsert with marker ObjectNode) + val markerNode = objectMapper.createObjectNode() + markerNode.put("id", "__tbw:test:0:1") + markerNode.put("ttl", 86400) + markerNode.put("pk", "user-A") + batch.upsertItemOperation(markerNode) + + val ops = batch.getOperations + ops.size() should be(4) + // Business items preserve their original order + ops.get(0).getOperationType.toString should be("CREATE") + ops.get(1).getOperationType.toString should be("UPSERT") + ops.get(2).getOperationType.toString should be("DELETE") + // Marker is the last operation and is an UPSERT + ops.get(3).getOperationType.toString should be("UPSERT") + } + + // ===================================================== + // Marker Document Edge Cases + // ===================================================== + + "buildMarkerDocument pattern" should "handle missing PK field in business item" in { + // If a business item is missing a PK field (e.g., HPK with /tenantId/userId/sessionId + // but the document has no "sessionId"), the marker should omit that field too. 
+ val om = new ObjectMapper() + val businessItem = om.createObjectNode() + businessItem.put("id", "doc-1") + businessItem.put("tenantId", "Contoso") + businessItem.put("userId", "alice") + // sessionId is MISSING + + val partitionKeyPaths = List("/tenantId", "/userId", "/sessionId") + val markerNode = om.createObjectNode() + markerNode.put("id", "__tbw:job:0:1") + markerNode.put("ttl", 86400) + partitionKeyPaths.foreach(path => { + val fieldName = path.stripPrefix("/") + val value = businessItem.get(fieldName) + if (value != null) { + markerNode.set(fieldName, value.deepCopy()) + } + }) + + markerNode.get("tenantId").asText() should be("Contoso") + markerNode.get("userId").asText() should be("alice") + markerNode.has("sessionId") should be(false) // missing in business item → missing in marker + } + + it should "work with single partition key" in { + val om = new ObjectMapper() + val businessItem = om.createObjectNode() + businessItem.put("id", "doc-1") + businessItem.put("pk", "Seattle") + businessItem.put("temperature", 72) + + val partitionKeyPaths = List("/pk") + val markerNode = om.createObjectNode() + markerNode.put("id", "__tbw:job:0:1") + markerNode.put("ttl", 86400) + partitionKeyPaths.foreach(path => { + val fieldName = path.stripPrefix("/") + val value = businessItem.get(fieldName) + if (value != null) { + markerNode.set(fieldName, value.deepCopy()) + } + }) + + markerNode.get("id").asText() should be("__tbw:job:0:1") + markerNode.get("ttl").asInt() should be(86400) + markerNode.get("pk").asText() should be("Seattle") + markerNode.has("temperature") should be(false) // business field excluded + } + + it should "not mutate the original business item" in { + val om = new ObjectMapper() + val businessItem = om.createObjectNode() + businessItem.put("id", "doc-1") + businessItem.put("pk", "user-A") + businessItem.put("score", 42) + + val partitionKeyPaths = List("/pk") + val markerNode = om.createObjectNode() + markerNode.put("id", "__tbw:job:0:1") + 
markerNode.put("ttl", 86400) + partitionKeyPaths.foreach(path => { + val fieldName = path.stripPrefix("/") + val value = businessItem.get(fieldName) + if (value != null) { + markerNode.set(fieldName, value.deepCopy()) + } + }) + + // Original business item is unchanged + businessItem.get("id").asText() should be("doc-1") + businessItem.get("pk").asText() should be("user-A") + businessItem.get("score").asInt() should be(42) + businessItem.has("ttl") should be(false) // marker's ttl was NOT added to business item + } + + // ===================================================== + // Marker Verification Outcome Pattern Tests + // ===================================================== + + "MarkerVerificationOutcome pattern" should "distinguish three outcomes" in { + // Verify the sealed trait / case object pattern used in verifyBatchCommit + // This tests the pattern matching logic — not the actual Cosmos DB call + sealed trait TestOutcome + case object TestCommitted extends TestOutcome + case object TestNotCommitted extends TestOutcome + case object TestInconclusive extends TestOutcome + + def simulateVerification(statusCode: Int): TestOutcome = statusCode match { + case 200 => TestCommitted // marker present → batch committed + case 404 => TestNotCommitted // marker absent → batch did not commit + case _ => TestInconclusive // transient error → inconclusive + } + + simulateVerification(200) should be(TestCommitted) + simulateVerification(404) should be(TestNotCommitted) + simulateVerification(408) should be(TestInconclusive) // Request Timeout + simulateVerification(503) should be(TestInconclusive) // Service Unavailable + simulateVerification(500) should be(TestInconclusive) // Internal Server Error + } + + // ===================================================== + // Inconclusive Retry Eligibility Tests + // ===================================================== + + "Inconclusive retry eligibility" should "depend on attempt budget, not original batch status code" in { 
+ // scenario: ItemAppend batch gets 409 on retry (shouldIgnore-eligible). + // Marker verification fails transiently → Inconclusive. + // The retry decision must be based on attemptNumber < maxRetryCount, + + val maxRetryCount = 10 + + // 409 is NOT transient + Exceptions.canBeTransientFailure(409, 0) should be(false) + + // Attempt 2 of 10: should be eligible for retry (budget remains) + val attemptNumber = 2 + (attemptNumber < maxRetryCount) should be(true) + + // Attempt 10 of 10: should NOT be eligible (budget exhausted) + val lastAttempt = 10 + (lastAttempt < maxRetryCount) should be(false) + + // Also verify that non-shouldIgnore-eligible transient codes are irrelevant here — + // the Inconclusive path doesn't care what the original batch status was, + // only whether there is retry budget left. + // 404/0 for ItemDelete (shouldIgnore-eligible, not transient for non-ItemOverwrite): + Exceptions.canBeTransientFailure(404, 0) should be(false) + // Even with a non-transient original status, retry is allowed if budget remains: + val midAttempt = 5 + (midAttempt < maxRetryCount) should be(true) + } +} +//scalastyle:on null +//scalastyle:on magic.number +//scalastyle:on multiple.string.literals +