Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,18 @@ jobs:

- name: Install sbt
run: |
echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | sudo tee /etc/apt/sources.list.d/sbt.list
curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | sudo apt-key add
sudo apt-get update -q
sudo apt-get install -yq sbt
SBT_VERSION="$(sed -n 's/^sbt.version *= *//p' project/build.properties | tr -d ' ')"
mkdir -p "$HOME/.local/bin"
curl -L -o "$HOME/.local/bin/sbt-launch.jar" \
"https://repo1.maven.org/maven2/org/scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch-${SBT_VERSION}.jar"
cat > "$HOME/.local/bin/sbt" <<EOF
#!/usr/bin/env bash
exec java -Dsbt.version="${SBT_VERSION}" -jar "$HOME/.local/bin/sbt-launch.jar" "\$@"
EOF
chmod +x "$HOME/.local/bin/sbt"
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
export PATH="$HOME/.local/bin:$PATH"
sbt sbtVersion

- name: Scalastyle check
run: sbt scalastyle test:scalastyle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
package com.microsoft.azure.synapse.ml.services.language

import com.microsoft.azure.synapse.ml.logging.{ FeatureNames, SynapseMLLogging }
import com.microsoft.azure.synapse.ml.io.http.ErrorUtils
import com.microsoft.azure.synapse.ml.param.ServiceParam
import com.microsoft.azure.synapse.ml.services._
import com.microsoft.azure.synapse.ml.services.text.{ TADocument, TextAnalyticsAutoBatch }
import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, UDFTransformer }
import com.microsoft.azure.synapse.ml.stages.{ FixedMiniBatchTransformer, FlattenBatch, HasBatchSize, Lambda,
UDFTransformer }
import org.apache.http.entity.{ AbstractHttpEntity, StringEntity }
import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.{ ComplexParamsReadable, NamespaceInjections, PipelineModel }
import org.apache.spark.ml.param.{ Param, ParamValidators }
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Row
import org.apache.spark.sql.{ Column, Row, functions => F }
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types.{ ArrayType, DataType, StructType }
import spray.json._
Expand Down Expand Up @@ -258,19 +260,47 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid)
}
}

protected def postprocessResponseUdf: UserDefinedFunction = {
private def postprocessedOutputType: StructType = {
val responseType = responseDataType.asInstanceOf[StructType]
val results = responseType("results").dataType.asInstanceOf[StructType]
val outputType = ArrayType(
new StructType()
.add("statistics", results("statistics").dataType)
.add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType)
.add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType)
.add("modelVersion", results("modelVersion").dataType)
)
new StructType()
.add("statistics", results("statistics").dataType)
.add("documents", results("documents").dataType.asInstanceOf[ArrayType].elementType)
.add("errors", results("errors").dataType.asInstanceOf[ArrayType].elementType)
.add("modelVersion", results("modelVersion").dataType)
}

protected def postprocessResponseUdf: UserDefinedFunction = {
val outputType = ArrayType(postprocessedOutputType)
UDFUtils.oldUdf(postprocessResponse _, outputType)
}

private def responseErrorToErrorCol(error: Column): Column = {
F.when(error.isNotNull, F.struct(
F.to_json(error).as("response"),
F.lit(null).cast(ErrorUtils.ErrorSchema("status").dataType).as("status") // scalastyle:ignore null
))
}

private def outputWithoutResponseError(output: Column): Column = {
val outputType = postprocessedOutputType
F.when(output.isNotNull, F.struct(
output.getField("statistics").as("statistics"),
output.getField("documents").as("documents"),
F.lit(null).cast(outputType("errors").dataType).as("errors"), // scalastyle:ignore null
output.getField("modelVersion").as("modelVersion")
)).otherwise(F.lit(null).cast(outputType)) // scalastyle:ignore null
}

private def moveResponseErrorsToErrorCol(
dataset: org.apache.spark.sql.Dataset[_]): org.apache.spark.sql.DataFrame = {
val df = dataset.toDF
val output = F.col(getOutputCol)
val responseError = output.getField("errors")
df.withColumn(getErrorCol, F.coalesce(F.col(getErrorCol), responseErrorToErrorCol(responseError)))
.withColumn(getOutputCol, F.when(responseError.isNotNull, outputWithoutResponseError(output)).otherwise(output))
}

override protected def getInternalTransformer(schema: StructType): PipelineModel = {

val batcher = if (shouldAutoBatch(schema)) {
Expand All @@ -293,8 +323,14 @@ class AnalyzeText(override val uid: String) extends CognitiveServicesBase(uid)
None
}

val moveResponseErrors = if (shouldAutoBatch(schema)) {
Some(Lambda(moveResponseErrorsToErrorCol _).setTransformSchema((schema: StructType) => schema))
} else {
None
}

NamespaceInjections.pipelineModel(
Array(batcher, Some(pipe), Some(postprocess), flatten).flatten
Array(batcher, Some(pipe), Some(postprocess), flatten, moveResponseErrors).flatten
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,102 @@

package com.microsoft.azure.synapse.ml.services.language

import com.microsoft.azure.synapse.ml.core.test.base.TestBase
import com.microsoft.azure.synapse.ml.services.text.{SentimentAssessment, TextEndpoint}
import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing}
import com.microsoft.azure.synapse.ml.io.http.{HTTPRequestData, HTTPResponseData, HTTPSchema}
import org.apache.http.impl.client.CloseableHttpClient
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, flatten, map}
import org.scalactic.{Equality, TolerantNumerics}

object AnalyzeTextErrorRoutingTestData extends Serializable {
val ResponseWithDocumentError: String =
"""{
| "kind": "PiiEntityRecognition",
| "results": {
| "documents": [
| {
| "id": "0",
| "redactedText": "My SSN is ***********",
| "entities": [],
| "warnings": [],
| "statistics": null
| }
| ],
| "errors": [
| {
| "id": "1",
| "error": {
| "code": "InvalidArgument",
| "message": "Document exceeds the service character limit.",
| "target": "documents.1",
| "details": null,
| "innererror": {
| "code": "InvalidDocument",
| "innerError": "DocumentTooLong"
| }
| }
| }
| ],
| "modelVersion": "test",
| "statistics": {
| "documentsCount": 2,
| "validDocumentsCount": 1,
| "erroneousDocumentsCount": 1,
| "transactionsCount": 2
| }
| }
|}""".stripMargin

def okResponseHandler(
client: CloseableHttpClient,
request: HTTPRequestData): HTTPResponseData = {
HTTPSchema.stringToResponse(ResponseWithDocumentError, 200, "OK")
}
}

class AnalyzeTextErrorRoutingSuite extends TestBase {
import spark.implicits._

test("AnalyzeText moves document-level 200 response errors to errorCol") {
val model = new AnalyzeText()
.setSubscriptionKey("unused")
.setLocation("eastus")
.setTextCol("text")
.setLanguage("en")
.setKind("PiiEntityRecognition")
.setOutputCol("response")
.setErrorCol("error")
.setHandler(AnalyzeTextErrorRoutingTestData.okResponseHandler _)

val rows = model.transform(Seq("valid text", "too long").toDF("text").coalesce(1))
.select("response", "error")
.collect()

assert(rows.length == 2)
val (failedRows, successRows) = rows.partition(row => row.getAs[Row]("error") != null)
assert(successRows.length == 1)
assert(failedRows.length == 1)

val successResponse = successRows.head.getAs[Row]("response")
assert(successResponse.getAs[Row]("documents") != null)
assert(successResponse.getAs[Row]("errors") == null)

val failedResponse = failedRows.head.getAs[Row]("response")
assert(failedResponse.getAs[Row]("documents") == null)
assert(failedResponse.getAs[Row]("errors") == null)

val error = failedRows.head.getAs[Row]("error")
assert(error != null)
val errorResponse = error.getAs[String]("response")
assert(errorResponse.contains("InvalidArgument"))
assert(errorResponse.contains("Document exceeds the service character limit."))
assert(error.getAs[Row]("status") == null)
}
}

class EntityLinkingSuite extends TransformerFuzzing[AnalyzeText] with TextEndpoint {
override val compareDataInSerializationTest: Boolean = false

Expand Down
Loading