From dbead722cbfcdd953564bc745ab54b9e6aea65b4 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 23 Jun 2026 16:29:57 -0700 Subject: [PATCH 1/5] fix: route AnalyzeText document errors to errorCol Move Azure AI Language document-level errors returned inside HTTP 200 AnalyzeText responses from the response payload into the configured error column after auto-batch flattening. Preserve transport error precedence and add a no-network regression test for mixed document success/error responses. AB#4638662 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ml/services/language/AnalyzeText.scala | 58 ++++++++++--- .../services/language/AnalyzeTextSuite.scala | 87 +++++++++++++++++++ 2 files changed, 134 insertions(+), 11 deletions(-) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala index d3577657890..0cd1eb155a4 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala @@ -4,16 +4,18 @@ package com.microsoft.azure.synapse.ml.services.language import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } +import com.microsoft.azure.synapse.ml.io.http.ErrorUtils import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services._ import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } -import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, Lambda, + UDFTransformer } import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } import org.apache.spark.injections.UDFUtils import 
org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } import org.apache.spark.ml.param.{ Param, ParamValidators } import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.Row +import org.apache.spark.sql.{ Column, Row, functions => F } import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types.{ ArrayType, DataType, StructType } import spray.json._ @@ -258,19 +260,47 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) } } - protected def postprocessResponseUdf: UserDefinedFunction = { + private def postprocessedOutputType: StructType = { val responseType = responseDataType.asInstanceOf[StructType] val results = responseType("results").dataType.asInstanceOf[StructType] - val outputType = ArrayType( - new StructType() - .add("statistics", results("statistics").dataType) - .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType) - .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType) - .add("modelVersion", results("modelVersion").dataType) - ) + new StructType() + .add("statistics", results("statistics").dataType) + .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType) + .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType) + .add("modelVersion", results("modelVersion").dataType) + } + + protected def postprocessResponseUdf: UserDefinedFunction = { + val outputType = ArrayType(postprocessedOutputType) UDFUtils.oldUdf(postprocessResponse _, outputType) } + private def responseErrorToErrorCol(error: Column): Column = { + F.when(error.isNotNull, F.struct( + F.to_json(error).as("response"), + F.lit(null).cast(ErrorUtils.ErrorSchema("status").dataType).as("status") // scalastyle:ignore null + )) + } + + private def outputWithoutResponseError(output: Column): Column = { + val outputType = postprocessedOutputType + F.when(output.isNotNull, F.struct( + 
output.getField("statistics").as("statistics"), + output.getField("documents").as("documents"), + F.lit(null).cast(outputType("errors").dataType).as("errors"), // scalastyle:ignore null + output.getField("modelVersion").as("modelVersion") + )).otherwise(F.lit(null).cast(outputType)) // scalastyle:ignore null + } + + private def moveResponseErrorsToErrorCol( + dataset: org.apache.spark.sql.Dataset[_]): org.apache.spark.sql.DataFrame = { + val df = dataset.toDF + val output = F.col(getOutputCol) + val responseError = output.getField("errors") + df.withColumn(getErrorCol, F.coalesce(F.col(getErrorCol), responseErrorToErrorCol(responseError))) + .withColumn(getOutputCol, F.when(responseError.isNotNull, outputWithoutResponseError(output)).otherwise(output)) + } + override protected def getInternalTransformer(schema: StructType): PipelineModel = { val batcher = if (shouldAutoBatch(schema)) { @@ -293,8 +323,14 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) None } + val moveResponseErrors = if (shouldAutoBatch(schema)) { + Some(Lambda(moveResponseErrorsToErrorCol _).setTransformSchema((schema: StructType) => schema)) + } else { + None + } + NamespaceInjections.pipelineModel( - Array(batcher, Some(pipe), Some(postprocess), flatten).flatten + Array(batcher, Some(pipe), Some(postprocess), flatten, moveResponseErrors).flatten ) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala index 9efe369b241..8c0d9fee1ee 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala @@ -3,13 +3,100 @@ package com.microsoft.azure.synapse.ml.services.language +import com.microsoft.azure.synapse.ml.core.test.base.TestBase import 
com.microsoft.azure.synapse.ml.services.text.{SentimentAssessment, TextEndpoint} import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.io.http.{HTTPRequestData, HTTPResponseData, HTTPSchema} +import org.apache.http.impl.client.CloseableHttpClient import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, flatten, map} import org.scalactic.{Equality, TolerantNumerics} +object AnalyzeTextErrorRoutingTestData extends Serializable { + val ResponseWithDocumentError: String = + """{ + | "kind": "PiiEntityRecognition", + | "results": { + | "documents": [ + | { + | "id": "0", + | "redactedText": "My SSN is ***********", + | "entities": [], + | "warnings": [], + | "statistics": null + | } + | ], + | "errors": [ + | { + | "id": "1", + | "error": { + | "code": "InvalidArgument", + | "message": "Document exceeds the service character limit.", + | "target": "documents.1", + | "details": null, + | "innererror": { + | "code": "InvalidDocument", + | "innerError": "DocumentTooLong" + | } + | } + | } + | ], + | "modelVersion": "test", + | "statistics": { + | "documentsCount": 2, + | "validDocumentsCount": 1, + | "erroneousDocumentsCount": 1, + | "transactionsCount": 2 + | } + | } + |}""".stripMargin + + def okResponseHandler( + client: CloseableHttpClient, + request: HTTPRequestData): HTTPResponseData = { + HTTPSchema.stringToResponse(ResponseWithDocumentError, 200, "OK") + } +} + +class AnalyzeTextErrorRoutingSuite extends TestBase { + import spark.implicits._ + + test("AnalyzeText moves document-level 200 response errors to errorCol") { + val model = new AnalyzeText() + .setSubscriptionKey("unused") + .setLocation("eastus") + .setTextCol("text") + .setLanguage("en") + .setKind("PiiEntityRecognition") + .setOutputCol("response") + .setErrorCol("error") + .setHandler(AnalyzeTextErrorRoutingTestData.okResponseHandler _) + + val rows = 
model.transform(Seq("valid text", "too long").toDF("text").coalesce(1)) + .select("response", "error") + .collect() + + assert(rows.length == 2) + + val successResponse = rows.head.getAs[Row]("response") + assert(successResponse.getAs[Row]("documents") != null) + assert(successResponse.getAs[Row]("errors") == null) + assert(rows.head.getAs[Row]("error") == null) + + val failedResponse = rows(1).getAs[Row]("response") + assert(failedResponse.getAs[Row]("documents") == null) + assert(failedResponse.getAs[Row]("errors") == null) + + val error = rows(1).getAs[Row]("error") + assert(error != null) + val errorResponse = error.getAs[String]("response") + assert(errorResponse.contains("InvalidArgument")) + assert(errorResponse.contains("Document exceeds the service character limit.")) + assert(error.getAs[Row]("status") == null) + } +} + class EntityLinkingSuite extends TransformerFuzzing[AnalyzeText] with TextEndpoint { override val compareDataInSerializationTest: Boolean = false From 2608b447d62baa43de3380f5a9dd6ba801f66bb0 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 23 Jun 2026 17:04:27 -0700 Subject: [PATCH 2/5] ci: pin PR validation sbt launcher Use the sbt launcher version from project/build.properties instead of installing the latest apt sbt package. This keeps the JDK 11 PR validation job on the repository's sbt 1.10.11 launcher and avoids sbt 2.x rejecting JDK 11 before scalastyle can run. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/pr-validation.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 85bfae3c347..9a5f0ce332c 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -41,10 +41,16 @@ jobs: - name: Install sbt run: | - echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list - curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add - sudo apt-get update -q - sudo apt-get install -yq sbt + SBT_VERSION="$(sed -n 's/^sbt.version *= *//p' project/build.properties | tr -d ' ')" + mkdir -p "$HOME/.local/bin" + curl -L -o "$HOME/.local/bin/sbt-launch.jar" \ + "https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch-${SBT_VERSION}.jar" + cat > "$HOME/.local/bin/sbt" <<'EOF' + #!/usr/bin/env bash + exec java -jar "$HOME/.local/bin/sbt-launch.jar" "$@" + EOF + chmod +x "$HOME/.local/bin/sbt" + echo "$HOME/.local/bin" >> "$GITHUB_PATH" - name: Scalastyle check run: sbt scalastyle test:scalastyle From b381092f89ab870f3dc6bb532638105bbfe144fd Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 23 Jun 2026 17:08:37 -0700 Subject: [PATCH 3/5] ci: use pinned sbt wrapper in PR validation Invoke the downloaded sbt launcher explicitly so the GitHub runner does not resolve its preinstalled sbt 2.x binary under JDK 11. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/pr-validation.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 9a5f0ce332c..cff1a81b66a 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -53,7 +53,7 @@ jobs: echo "$HOME/.local/bin" >> "$GITHUB_PATH" - name: Scalastyle check - run: sbt scalastyle test:scalastyle + run: "$HOME/.local/bin/sbt" scalastyle test:scalastyle - name: Compile - run: sbt compile test:compile + run: "$HOME/.local/bin/sbt" compile test:compile From e199c6492d498dee25e8d59b4ddcd9fa0fcccfc4 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 23 Jun 2026 17:11:58 -0700 Subject: [PATCH 4/5] ci: prefer pinned sbt on PATH Keep PR validation commands as plain sbt while placing the repository-version launcher first on PATH for subsequent workflow steps. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/pr-validation.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index cff1a81b66a..0c3776bba5b 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -51,9 +51,11 @@ jobs: EOF chmod +x "$HOME/.local/bin/sbt" echo "$HOME/.local/bin" >> "$GITHUB_PATH" + export PATH="$HOME/.local/bin:$PATH" + sbt sbtVersion - name: Scalastyle check - run: "$HOME/.local/bin/sbt" scalastyle test:scalastyle + run: sbt scalastyle test:scalastyle - name: Compile - run: "$HOME/.local/bin/sbt" compile test:compile + run: sbt compile test:compile From 5e708a7db3108e9ecb1c33d44d9ecb7ce3751c04 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 23 Jun 2026 20:36:57 -0700 Subject: [PATCH 5/5] test: avoid ordering assumption in AnalyzeText error test Partition collected rows by error nullability instead of 
relying on collect order, addressing PR review feedback about Spark DataFrames being unordered. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../ml/services/language/AnalyzeTextSuite.scala | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala index 8c0d9fee1ee..84b6646428a 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala @@ -78,17 +78,19 @@ class AnalyzeTextErrorRoutingSuite extends TestBase { .collect() assert(rows.length == 2) + val (failedRows, successRows) = rows.partition(row => row.getAs[Row]("error") != null) + assert(successRows.length == 1) + assert(failedRows.length == 1) - val successResponse = rows.head.getAs[Row]("response") + val successResponse = successRows.head.getAs[Row]("response") assert(successResponse.getAs[Row]("documents") != null) assert(successResponse.getAs[Row]("errors") == null) - assert(rows.head.getAs[Row]("error") == null) - val failedResponse = rows(1).getAs[Row]("response") + val failedResponse = failedRows.head.getAs[Row]("response") assert(failedResponse.getAs[Row]("documents") == null) assert(failedResponse.getAs[Row]("errors") == null) - val error = rows(1).getAs[Row]("error") + val error = failedRows.head.getAs[Row]("error") assert(error != null) val errorResponse = error.getAs[String]("response") assert(errorResponse.contains("InvalidArgument"))