diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 85bfae3c34..0c3776bba5 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -41,10 +41,18 @@ jobs: - name: Install sbt run: | - echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list - curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add - sudo apt-get update -q - sudo apt-get install -yq sbt + SBT_VERSION="$(sed -n 's/^sbt.version *= *//p' project/build.properties | tr -d ' ')" + mkdir -p "$HOME/.local/bin" + curl -L -o "$HOME/.local/bin/sbt-launch.jar" \ + "https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch-${SBT_VERSION}.jar" + cat > "$HOME/.local/bin/sbt" <> "$GITHUB_PATH" + export PATH="$HOME/.local/bin:$PATH" + sbt sbtVersion - name: Scalastyle check run: sbt scalastyle test:scalastyle diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala index d357765789..0cd1eb155a 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeText.scala @@ -4,16 +4,18 @@ package com.microsoft.azure.synapse.ml.services.language import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging } +import com.microsoft.azure.synapse.ml.io.http.ErrorUtils import com.microsoft.azure.synapse.ml.param.ServiceParam import com.microsoft.azure.synapse.ml.services._ import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch } -import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer } +import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, Lambda, + UDFTransformer } import org.apache.http.entity.{ AbstractHttpEntity, StringEntity } import org.apache.spark.injections.UDFUtils import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel } import org.apache.spark.ml.param.{ Param, ParamValidators } import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.Row +import org.apache.spark.sql.{ Column, Row, functions => F } import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types.{ ArrayType, DataType, StructType } import spray.json._ @@ -258,19 +260,47 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) } } - protected def postprocessResponseUdf: UserDefinedFunction = { + private def postprocessedOutputType: StructType = { val responseType = responseDataType.asInstanceOf[StructType] val results = responseType("results").dataType.asInstanceOf[StructType] - val outputType = ArrayType( - new StructType() - .add("statistics", results("statistics").dataType) - .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType) - .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType) - .add("modelVersion", results("modelVersion").dataType) - ) + new StructType() + .add("statistics", results("statistics").dataType) + .add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType) + .add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType) + .add("modelVersion", results("modelVersion").dataType) + } + + protected def postprocessResponseUdf: UserDefinedFunction = { + val outputType = ArrayType(postprocessedOutputType) UDFUtils.oldUdf(postprocessResponse _, outputType) } + private def responseErrorToErrorCol(error: Column): Column = { + F.when(error.isNotNull, F.struct( + F.to_json(error).as("response"), + F.lit(null).cast(ErrorUtils.ErrorSchema("status").dataType).as("status") // scalastyle:ignore null + )) + } + + private def outputWithoutResponseError(output: Column): Column = { + val outputType = postprocessedOutputType + F.when(output.isNotNull, F.struct( + output.getField("statistics").as("statistics"), + output.getField("documents").as("documents"), + F.lit(null).cast(outputType("errors").dataType).as("errors"), // scalastyle:ignore null + output.getField("modelVersion").as("modelVersion") + )).otherwise(F.lit(null).cast(outputType)) // scalastyle:ignore null + } + + private def moveResponseErrorsToErrorCol( + dataset: org.apache.spark.sql.Dataset[_]): org.apache.spark.sql.DataFrame = { + val df = dataset.toDF + val output = F.col(getOutputCol) + val responseError = output.getField("errors") + df.withColumn(getErrorCol, F.coalesce(F.col(getErrorCol), responseErrorToErrorCol(responseError))) + .withColumn(getOutputCol, F.when(responseError.isNotNull, outputWithoutResponseError(output)).otherwise(output)) + } + override protected def getInternalTransformer(schema: StructType): PipelineModel = { val batcher = if (shouldAutoBatch(schema)) { @@ -293,8 +323,14 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid) None } + val moveResponseErrors = if (shouldAutoBatch(schema)) { + Some(Lambda(moveResponseErrorsToErrorCol _).setTransformSchema((schema: StructType) => schema)) + } else { + None + } + NamespaceInjections.pipelineModel( - Array(batcher, Some(pipe), Some(postprocess), flatten).flatten + Array(batcher, Some(pipe), Some(postprocess), flatten, moveResponseErrors).flatten ) } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala index 9efe369b24..84b6646428 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/language/AnalyzeTextSuite.scala @@ -3,13 +3,102 @@ package com.microsoft.azure.synapse.ml.services.language +import com.microsoft.azure.synapse.ml.core.test.base.TestBase import com.microsoft.azure.synapse.ml.services.text.{SentimentAssessment, TextEndpoint} import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} +import com.microsoft.azure.synapse.ml.io.http.{HTTPRequestData, HTTPResponseData, HTTPSchema} +import org.apache.http.impl.client.CloseableHttpClient import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, flatten, map} import org.scalactic.{Equality, TolerantNumerics} +object AnalyzeTextErrorRoutingTestData extends Serializable { + val ResponseWithDocumentError: String = + """{ + | "kind": "PiiEntityRecognition", + | "results": { + | "documents": [ + | { + | "id": "0", + | "redactedText": "My SSN is ***********", + | "entities": [], + | "warnings": [], + | "statistics": null + | } + | ], + | "errors": [ + | { + | "id": "1", + | "error": { + | "code": "InvalidArgument", + | "message": "Document exceeds the service character limit.", + | "target": "documents.1", + | "details": null, + | "innererror": { + | "code": "InvalidDocument", + | "innerError": "DocumentTooLong" + | } + | } + | } + | ], + | "modelVersion": "test", + | "statistics": { + | "documentsCount": 2, + | "validDocumentsCount": 1, + | "erroneousDocumentsCount": 1, + | "transactionsCount": 2 + | } + | } + |}""".stripMargin + + def okResponseHandler( + client: CloseableHttpClient, + request: HTTPRequestData): HTTPResponseData = { + HTTPSchema.stringToResponse(ResponseWithDocumentError, 200, "OK") + } +} + +class AnalyzeTextErrorRoutingSuite extends TestBase { + import spark.implicits._ + + test("AnalyzeText moves document-level 200 response errors to errorCol") { + val model = new AnalyzeText() + .setSubscriptionKey("unused") + .setLocation("eastus") + .setTextCol("text") + .setLanguage("en") + .setKind("PiiEntityRecognition") + .setOutputCol("response") + .setErrorCol("error") + .setHandler(AnalyzeTextErrorRoutingTestData.okResponseHandler _) + + val rows = model.transform(Seq("valid text", "too long").toDF("text").coalesce(1)) + .select("response", "error") + .collect() + + assert(rows.length == 2) + val (failedRows, successRows) = rows.partition(row => row.getAs[Row]("error") != null) + assert(successRows.length == 1) + assert(failedRows.length == 1) + + val successResponse = successRows.head.getAs[Row]("response") + assert(successResponse.getAs[Row]("documents") != null) + assert(successResponse.getAs[Row]("errors") == null) + + val failedResponse = failedRows.head.getAs[Row]("response") + assert(failedResponse.getAs[Row]("documents") == null) + assert(failedResponse.getAs[Row]("errors") == null) + + val error = failedRows.head.getAs[Row]("error") + assert(error != null) + val errorResponse = error.getAs[String]("response") + assert(errorResponse.contains("InvalidArgument")) + assert(errorResponse.contains("Document exceeds the service character limit.")) + assert(error.getAs[Row]("status") == null) + } +} + class EntityLinkingSuite extends TransformerFuzzing[AnalyzeText] with TextEndpoint { override val compareDataInSerializationTest: Boolean = false