diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index 28ff648ee90..69ffa656a57 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -137,3 +137,43 @@ def get_api_type(self): def reset_api_type(self): self.defaults.resetApiType() + + def set_api_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError( + f"API timeout must be greater than 0, got: {timeout_float}" + ) + self.defaults.setApiTimeout(timeout_float) + + def get_api_timeout(self): + return getOption(self.defaults.getApiTimeout()) + + def reset_api_timeout(self): + self.defaults.resetApiTimeout() + + def set_connection_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError( + f"Connection timeout must be greater than 0, got: {timeout_float}" + ) + self.defaults.setConnectionTimeout(timeout_float) + + def get_connection_timeout(self): + return getOption(self.defaults.getConnectionTimeout()) + + def reset_connection_timeout(self): + self.defaults.resetConnectionTimeout() + + def set_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError(f"Timeout must be greater than 0, got: {timeout_float}") + self.defaults.setTimeout(timeout_float) + + def get_timeout(self): + return getOption(self.defaults.getTimeout()) + + def reset_timeout(self): + self.defaults.resetTimeout() diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index cdf63544236..cf9721a78ca 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -579,18 +579,22 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform case l => l } + val baseTransformer = new SimpleHTTPTransformer() + .setInputCol(dynamicParamColName) + .setOutputCol(getOutputCol) + .setInputParser(getInternalInputParser(schema)) + .setOutputParser(getInternalOutputParser(schema)) + .setHandler(handlingFunc _) + .setConcurrency(getConcurrency) + .setConcurrentTimeout(get(concurrentTimeout)) + .setApiTimeout(getApiTimeout) + .setConnectionTimeout(getConnectionTimeout) + .setErrorCol(getErrorCol) + val transformer = get(timeout).map(baseTransformer.setTimeout).getOrElse(baseTransformer) + val stages = Array( Lambda(_.withColumn(dynamicParamColName, struct(dynamicParamCols: _*))), - new SimpleHTTPTransformer() - .setInputCol(dynamicParamColName) - .setOutputCol(getOutputCol) - .setInputParser(getInternalInputParser(schema)) - .setOutputParser(getInternalOutputParser(schema)) - .setHandler(handlingFunc _) - .setConcurrency(getConcurrency) - .setConcurrentTimeout(get(concurrentTimeout)) - .setTimeout(getTimeout) - .setErrorCol(getErrorCol), + transformer, new DropColumns().setCol(dynamicParamColName) ) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index c50509ed5a9..58960ca9616 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -108,6 +108,9 @@ case object OpenAITopPKey extends GlobalKey[Either[Double, String]] case object OpenAIVerbosityKey extends GlobalKey[Either[String, String]] case object OpenAIReasoningEffortKey extends GlobalKey[Either[String, String]] case object OpenAIApiTypeKey extends GlobalKey[String] +case object OpenAIApiTimeoutKey extends GlobalKey[Double] +case object OpenAIConnectionTimeoutKey extends GlobalKey[Double] +case object OpenAITimeoutKey extends GlobalKey[Double] // scalastyle:off number.of.methods trait HasOpenAITextParams extends HasOpenAISharedParams { @@ -412,7 +415,7 @@ trait HasTextOutput { abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) with HasOpenAISharedParams with OpenAIFabricSetting { - setDefault(timeout -> 360.0) + setDefault(apiTimeout -> 600.0) private def usingDefaultOpenAIEndpoint(): Boolean = { getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/" diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index 8d63032898a..99ff6b1aa1d 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -161,6 +161,45 @@ object OpenAIDefaults { GlobalParams.resetGlobalParam(OpenAIApiTypeKey) } + def setApiTimeout(v: Double): Unit = { + require(v > 0, s"API timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAIApiTimeoutKey, v) + } + + def getApiTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAIApiTimeoutKey) + } + + def resetApiTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAIApiTimeoutKey) + } + + def setConnectionTimeout(v: Double): Unit = { + require(v > 0, s"Connection timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAIConnectionTimeoutKey, v) + } + + def getConnectionTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAIConnectionTimeoutKey) + } + + def resetConnectionTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAIConnectionTimeoutKey) + } + + def setTimeout(v: Double): Unit = { + require(v > 0, s"Timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAITimeoutKey, v) + } + + def getTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAITimeoutKey) + } + + def resetTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAITimeoutKey) + } + private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = { optEither match { case Some(Left(v)) => Some(v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index e980507066d..cfe1530af79 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -28,6 +28,10 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) with HasReturnUsage { logClass(FeatureNames.AiServices.OpenAI) + GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey) + GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey) + GlobalParams.registerParam(timeout, OpenAITimeoutKey) + def this() = this(Identifiable.randomUID("OpenAIEmbedding")) def urlPath: String = "" diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 8cca1880b4e..6e1425a62ee 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -49,6 +49,10 @@ class OpenAIPrompt(override val uid: String) extends Transformer logClass(FeatureNames.AiServices.OpenAI) + GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey) + GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey) + GlobalParams.registerParam(timeout, OpenAITimeoutKey) + def this() = this(Identifiable.randomUID("OpenAIPrompt")) override def copy(extra: ParamMap): Transformer = defaultCopy(extra) @@ -176,7 +180,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer systemPrompt -> defaultSystemPrompt, apiType -> "chat_completions", columnTypes -> Map.empty, - timeout -> 360.0 + apiTimeout -> 600.0 ) override def setCustomServiceName(v: String): this.type = { diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py index df96c8e908d..1eb9fc14fc3 100644 --- a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py @@ -30,6 +30,9 @@ def test_setters_and_getters(self): defaults.set_api_version("2024-05-01-preview") defaults.set_model("grok-3-mini") defaults.set_embedding_deployment_name("text-embedding-ada-002") + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(120.0) self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") @@ -42,6 +45,9 @@ def test_setters_and_getters(self): self.assertEqual( defaults.get_embedding_deployment_name(), "text-embedding-ada-002" ) + self.assertEqual(defaults.get_api_timeout(), 600.0) + self.assertEqual(defaults.get_connection_timeout(), 5.0) + self.assertEqual(defaults.get_timeout(), 120.0) def test_resetters(self): defaults = OpenAIDefaults() @@ -55,6 +61,9 @@ def test_resetters(self): defaults.set_api_version("2024-05-01-preview") defaults.set_model("grok-3-mini") defaults.set_embedding_deployment_name("text-embedding-ada-002") + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(120.0) self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") @@ -67,6 +76,9 @@ def test_resetters(self): self.assertEqual( defaults.get_embedding_deployment_name(), "text-embedding-ada-002" ) + self.assertEqual(defaults.get_api_timeout(), 600.0) + self.assertEqual(defaults.get_connection_timeout(), 5.0) + self.assertEqual(defaults.get_timeout(), 120.0) defaults.reset_deployment_name() defaults.reset_subscription_key() @@ -77,6 +89,9 @@ def test_resetters(self): defaults.reset_api_version() defaults.reset_model() defaults.reset_embedding_deployment_name() + defaults.reset_api_timeout() + defaults.reset_connection_timeout() + defaults.reset_timeout() self.assertEqual(defaults.get_deployment_name(), None) self.assertEqual(defaults.get_subscription_key(), None) @@ -87,6 +102,9 @@ def test_resetters(self): self.assertEqual(defaults.get_api_version(), None) self.assertEqual(defaults.get_model(), None) self.assertEqual(defaults.get_embedding_deployment_name(), None) + self.assertEqual(defaults.get_api_timeout(), None) + self.assertEqual(defaults.get_connection_timeout(), None) + self.assertEqual(defaults.get_timeout(), None) def test_two_defaults(self): defaults = OpenAIDefaults() @@ -168,6 +186,28 @@ def test_parameter_validation(self): with self.assertRaises(ValueError): defaults.set_top_p(1.1) + # Test valid timeout values + defaults.set_api_timeout(1.0) + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(1.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(60.0) + defaults.set_timeout(120.0) + + # Test invalid timeout values (must be > 0) + with self.assertRaises(ValueError): + defaults.set_api_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_api_timeout(-1.0) + with self.assertRaises(ValueError): + defaults.set_connection_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_connection_timeout(-1.0) + with self.assertRaises(ValueError): + defaults.set_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_timeout(-1.0) + class TestResponseFormatJsonSchema(unittest.TestCase): def setUp(self): diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index d226ebe64e3..55b7f5c466a 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -88,6 +88,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setVerbosity("medium") OpenAIDefaults.setReasoningEffort("medium") OpenAIDefaults.setApiType("responses") + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(120.0) assert(OpenAIDefaults.getDeploymentName.contains(deploymentName)) assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey)) @@ -101,6 +104,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { assert(OpenAIDefaults.getVerbosity.contains("medium")) assert(OpenAIDefaults.getReasoningEffort.contains("medium")) assert(OpenAIDefaults.getApiType.contains("responses")) + assert(OpenAIDefaults.getApiTimeout.contains(600.0)) + assert(OpenAIDefaults.getConnectionTimeout.contains(5.0)) + assert(OpenAIDefaults.getTimeout.contains(120.0)) } test("Test Resetters") { @@ -116,6 +122,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setVerbosity("medium") OpenAIDefaults.setReasoningEffort("medium") OpenAIDefaults.setApiType("responses") + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(120.0) OpenAIDefaults.resetDeploymentName() OpenAIDefaults.resetSubscriptionKey() @@ -129,6 +138,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.resetVerbosity() OpenAIDefaults.resetReasoningEffort() OpenAIDefaults.resetApiType() + OpenAIDefaults.resetApiTimeout() + OpenAIDefaults.resetConnectionTimeout() + OpenAIDefaults.resetTimeout() assert(OpenAIDefaults.getDeploymentName.isEmpty) assert(OpenAIDefaults.getSubscriptionKey.isEmpty) @@ -142,6 +154,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { assert(OpenAIDefaults.getVerbosity.isEmpty) assert(OpenAIDefaults.getReasoningEffort.isEmpty) assert(OpenAIDefaults.getApiType.isEmpty) + assert(OpenAIDefaults.getApiTimeout.isEmpty) + assert(OpenAIDefaults.getConnectionTimeout.isEmpty) + assert(OpenAIDefaults.getTimeout.isEmpty) } test("Test Parameter Validation") { @@ -183,5 +198,33 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { // Test reasoning effort values OpenAIDefaults.setReasoningEffort("low") OpenAIDefaults.setReasoningEffort("anything") + + // Test valid timeout values + OpenAIDefaults.setApiTimeout(1.0) + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(1.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(60.0) + OpenAIDefaults.setTimeout(120.0) + + // Test invalid timeout values (must be > 0) + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setApiTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setApiTimeout(-1.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setConnectionTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setConnectionTimeout(-1.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setTimeout(-1.0) + } } } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala similarity index 71% rename from cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala rename to cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala index beeb11afde7..552265121d3 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala @@ -11,7 +11,7 @@ import org.apache.spark.ml.linalg.{Vector, SQLDataTypes} import org.apache.spark.sql.types.StructType import org.scalactic.Equality -class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with OpenAIAPIKey with Flaky { +class OpenAIEmbeddingSuite extends TransformerFuzzing[OpenAIEmbedding] with OpenAIAPIKey with Flaky { import spark.implicits._ @@ -203,6 +203,86 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope assert(results(2).getAs[Row]("null_test_usage") != null) } + test("Timeout Configuration") { + // Test that timeout parameters are properly set and retrieved + val embeddingWithTimeout = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + .setApiTimeout(300.0) + .setConnectionTimeout(10.0) + .setTimeout(60.0) + + assert(embeddingWithTimeout.getApiTimeout == 300.0) + assert(embeddingWithTimeout.getConnectionTimeout == 10.0) + assert(embeddingWithTimeout.getTimeout == 60.0) + } + + test("Short Timeout Returns Timeout Error") { + val embeddingWithShortTimeout = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + .setTimeout(0.001) // Very short timeout to force timeout error + + val errorCol = embeddingWithShortTimeout.getErrorCol + val results = embeddingWithShortTimeout + .transform(df) + .select("out", errorCol) + .collect() + + // All rows should have timeout errors due to very short timeout + results.foreach { row => + val errorRow = row.getAs[Row](1) + assert(errorRow != null, "Should have error due to timeout") + val errorResponse = errorRow.getAs[String]("response") + assert(errorResponse.contains("exceeded the time limit") || + errorResponse.contains("timeout"), + s"Error should mention timeout, got: $errorResponse") + } + } + + test("Embedding uses global timeout defaults from OpenAIDefaults") { + val originalApiTimeout = OpenAIDefaults.getApiTimeout + val originalConnectionTimeout = OpenAIDefaults.getConnectionTimeout + val originalTimeout = OpenAIDefaults.getTimeout + + try { + OpenAIDefaults.setApiTimeout(350.0) + OpenAIDefaults.setConnectionTimeout(15.0) + OpenAIDefaults.setTimeout(120.0) + + val e = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + + // The embedding should work with global defaults applied + val results = e.transform(df.limit(1)).collect() + assert(results.length == 1) + assert(results(0).getAs[Vector]("out").size > 0) + } finally { + // Reset global defaults to avoid cross-test contamination + originalApiTimeout match { + case Some(v) => OpenAIDefaults.setApiTimeout(v) + case None => OpenAIDefaults.resetApiTimeout() + } + originalConnectionTimeout match { + case Some(v) => OpenAIDefaults.setConnectionTimeout(v) + case None => OpenAIDefaults.resetConnectionTimeout() + } + originalTimeout match { + case Some(v) => OpenAIDefaults.setTimeout(v) + case None => OpenAIDefaults.resetTimeout() + } + } + } override def testObjects(): Seq[TestObject[OpenAIEmbedding]] = Seq(new TestObject(embedding, df)) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index b902d7afb08..346b38328eb 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -398,6 +398,79 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK )) } + lazy val longInputDf: DataFrame = Seq( + ("apple", "fruits"), + (null, null), // scalastyle:ignore null + ("Flying on the weekends is a lot of fun! " * 1000, "travel"), + ("Flying is a lot of fun! " * 10000, "travel") + ).toDF("text", "category") + + test("Long Input Handling") { + val results = prompt + .setPromptTemplate("Summarize the following text in 10 words or less: {text}") + .setTimeout(120.0) + .transform(longInputDf) + .select("outParsed", prompt.getErrorCol) + .collect() + + // Row 0: "apple" - normal input, should have valid output + assert(Option(results(0).get(0)).isDefined) + + // Row 1: null input should return null output + assert(results(1).get(0) == null) + + // Row 2: 1000 repetitions - may succeed or fail depending on model limits + val row2HasOutput = Option(results(2).get(0)).isDefined + val row2HasError = Option(results(2).getAs[Row](1)).isDefined + assert(row2HasOutput || row2HasError, "Row 2 should have either output or error") + + // Row 3: 10000 repetitions - possible to exceed token limits + val row3HasOutput = Option(results(3).get(0)).isDefined + val row3HasError = Option(results(3).getAs[Row](1)).isDefined + assert(row3HasOutput || row3HasError, "Row 3 should have either output or error") + } + + test("Timeout Configuration") { + // Test that timeout parameters are properly set and retrieved + val promptWithTimeout = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setApiTimeout(300.0) + .setConnectionTimeout(10.0) + .setTimeout(60.0) + + assert(promptWithTimeout.getApiTimeout == 300.0) + assert(promptWithTimeout.getConnectionTimeout == 10.0) + assert(promptWithTimeout.getTimeout == 60.0) + } + + test("Short Timeout Returns Timeout Error") { + val promptWithShortTimeout = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setPromptTemplate("List 5 {category}") + .setTimeout(0.001) // Very short timeout to force timeout error + + val results = promptWithShortTimeout + .transform(df) + .select("outParsed", promptWithShortTimeout.getErrorCol) + .collect() + + // All rows should have timeout errors due to very short timeout + results.foreach { row => + val errorRow = row.getAs[Row](1) + assert(errorRow != null, "Should have error due to timeout") + val errorResponse = errorRow.getAs[String]("response") + assert(errorResponse.contains("exceeded the time limit") || + errorResponse.contains("timeout"), + s"Error should mention timeout, got: $errorResponse") + } + } + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala index c192a0fa81f..f4b6978b0eb 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala @@ -21,8 +21,11 @@ private[ml] trait BaseClient { def this(response: Option[ResponseType]) = this(response, None) } - case class RequestWithContext(request: Option[RequestType], context: Context) { - def this(request: Option[RequestType]) = this(request, None) + case class RequestWithContext(request: Option[RequestType], + context: Context, + precomputedResponse: Option[ResponseType] = None) { + def this(request: Option[RequestType]) = this(request, None, None) + def this(request: Option[RequestType], context: Context) = this(request, context, None) } protected lazy val logger: Logger = LogManager.getLogger("BaseClient") diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala index 175d238dc09..79851048ea1 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala @@ -32,10 +32,11 @@ private[ml] trait HTTPClient extends BaseClient override type RequestType = HTTPRequestData protected val requestTimeout: Int + protected val connectionTimeout: Int protected val requestConfig: RequestConfig = RequestConfig.custom() - .setConnectTimeout(requestTimeout) - .setConnectionRequestTimeout(requestTimeout) + .setConnectTimeout(connectionTimeout) + .setConnectionRequestTimeout(connectionTimeout) .setSocketTimeout(requestTimeout) .build() @@ -55,9 +56,15 @@ private[ml] trait HTTPClient extends BaseClient } protected def sendRequestWithContext(request: RequestWithContext): ResponseWithContext = { - request.request.map(req => - ResponseWithContext(Some(handle(internalClient, req)), request.context) - ).getOrElse(ResponseWithContext(None, request.context)) + // If there's a precomputed response (e.g., timeout), return it without making the HTTP call + request.precomputedResponse match { + case Some(response) => + ResponseWithContext(Some(response), request.context) + case None => + request.request.map(req => + ResponseWithContext(Some(handle(internalClient, req)), request.context) + ).getOrElse(ResponseWithContext(None, request.context)) + } } } @@ -184,7 +191,8 @@ object HandlingUtils extends SparkLogging { class AsyncHTTPClient(val handler: HandlingUtils.HandlerFunc, override val concurrency: Int, override val timeout: Duration, - val requestTimeout: Int) + val requestTimeout: Int, + val connectionTimeout: Int) (override implicit val ec: ExecutionContext) extends AsyncClient(concurrency, timeout)(ec) with HTTPClient { override def handle(client: CloseableHttpClient, @@ -193,7 +201,9 @@ class AsyncHTTPClient(val handler: HandlingUtils.HandlerFunc, } } -class SingleThreadedHTTPClient(val handler: HandlingUtils.HandlerFunc, val requestTimeout: Int) +class SingleThreadedHTTPClient(val handler: HandlingUtils.HandlerFunc, + val requestTimeout: Int, + val connectionTimeout: Int) extends HTTPClient with SingleThreadedClient { override def handle(client: CloseableHttpClient, request: HTTPRequestData): HTTPResponseData = blocking { diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala index 46ffbfc237c..ce3a540687e 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala @@ -290,7 +290,7 @@ object HTTPSchema { HTTPResponseData( Array(), Some(stringToEntity(x)), - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } @@ -304,7 +304,7 @@ object HTTPSchema { HTTPResponseData( Array(), None, - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } @@ -318,7 +318,7 @@ object HTTPSchema { HTTPResponseData( Array(), Some(binaryToEntity(x)), - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index 4fba4c03852..75108a712a9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -51,8 +51,30 @@ trait ConcurrencyParams extends Wrappable { /** @group setParam */ def setConcurrency(value: Int): this.type = set(concurrency, value) + val apiTimeout: Param[Double] = new DoubleParam( + this, "apiTimeout", "number of seconds to wait for a response from the API") + + /** @group getParam */ + def getApiTimeout: Double = $(apiTimeout) + + /** @group setParam */ + def setApiTimeout(value: Double): this.type = set(apiTimeout, value) + + val connectionTimeout: Param[Double] = new DoubleParam( + this, "connectionTimeout", "number of seconds to wait for establishing a connection") + + /** @group getParam */ + def getConnectionTimeout: Double = $(connectionTimeout) + + /** @group setParam */ + def setConnectionTimeout(value: Double): this.type = set(connectionTimeout, value) + val timeout: Param[Double] = new DoubleParam( - this, "timeout", "number of seconds to wait before closing the connection") + this, "timeout", + "number of seconds for the entire DataFrame transformation to complete, " + + "measured from the start of the transform operation; " + + "rows processed after this timeout will receive HTTP 408 responses without making API calls" + ) /** @group getParam */ def getTimeout: Double = $(timeout) @@ -73,7 +95,7 @@ trait ConcurrencyParams extends Wrappable { case Some(v) => setConcurrentTimeout(v) case None => clear(concurrentTimeout) } - setDefault(concurrency -> 1, timeout -> 60.0) + setDefault(concurrency -> 1, apiTimeout -> 60.0, connectionTimeout -> 5.0) } case object URLKey extends GlobalKey[String] @@ -106,16 +128,32 @@ class HTTPTransformer(val uid: String) val clientHolder = SharedVariable { getConcurrency match { - case 1 => new SingleThreadedHTTPClient(getHandler, (getTimeout * 1000).toInt) + case 1 => new SingleThreadedHTTPClient( + getHandler, + (getApiTimeout * 1000).toInt, + (getConnectionTimeout * 1000).toInt) case n if n > 1 => val dur = get(concurrentTimeout) .map(ct => Duration.fromNanos((ct * math.pow(10, 9)).toLong)) //scalastyle:ignore magic.number .getOrElse(Duration.Inf) val ec = ExecutionContext.global - new AsyncHTTPClient(getHandler, n, dur, (getTimeout * 1000).toInt)(ec) + new AsyncHTTPClient( + getHandler, + n, + dur, + (getApiTimeout * 1000).toInt, + (getConnectionTimeout * 1000).toInt)(ec) } } + private def createTimeoutResponse(timeoutSeconds: Double): HTTPResponseData = { + HTTPSchema.stringToResponse( + f"The operation exceeded the time limit of $timeoutSeconds%.1f seconds. " + + "Fix: increase value of timeout or reset it for no limit.", + 408, //scalastyle:ignore magic.number HTTP_REQUEST_TIMEOUT + "Request Timeout") + } + /** @param dataset - The input dataset, to be transformed * @return The DataFrame that results from column selection */ @@ -126,16 +164,51 @@ class HTTPTransformer(val uid: String) val colIndex = df.schema.fieldNames.indexOf(getInputCol) val fromRow = HTTPRequestData.makeFromRowConverter val toRow = HTTPResponseData.makeToRowConverter + + val timeoutMs = get(timeout).map(t => (t * 1000).toLong) + val startTime = timeoutMs.map(_ => System.currentTimeMillis()) + val startTimeBroadcast = startTime.map(df.sparkSession.sparkContext.broadcast(_)) + + val timeoutSec = get(timeout) + df.mapPartitions { it => if (!it.hasNext) { Iterator() } else { - val c = clientHolder.get - val responsesWithContext = c.sendRequestsWithContext(it.map { row => - c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row)) - }) - responsesWithContext.map { rwc => - Row.fromSeq(rwc.context.get.asInstanceOf[Row].toSeq :+ rwc.response.flatMap(Option.apply).map(toRow).orNull) + // Early timeout check before creating HTTP client to avoid network errors + val isAlreadyTimedOut = (timeoutMs, startTimeBroadcast) match { + case (Some(tm), Some(startBroadcast)) => + System.currentTimeMillis() - startBroadcast.value > tm + case _ => false + } + + if (isAlreadyTimedOut) { + // Return timeout responses for all rows without creating HTTP client + it.map { row => + Row.fromSeq(row.toSeq :+ toRow(createTimeoutResponse(timeoutSec.get))) + } + } else { + val c = clientHolder.get + + val responsesWithContext = c.sendRequestsWithContext(it.map { row => + // Check if timeout has been exceeded + val isTimedOut = (timeoutMs, startTimeBroadcast) match { + case (Some(tm), Some(startBroadcast)) => + System.currentTimeMillis() - startBroadcast.value > tm + case _ => false + } + + if (isTimedOut) { + // Return a timeout response without making the API call + c.RequestWithContext(None, Some(row), Some(createTimeoutResponse(timeoutSec.get))) + } else { + c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row)) + } + }) + responsesWithContext.map { rwc => + Row.fromSeq(rwc.context.get.asInstanceOf[Row].toSeq :+ + rwc.response.flatMap(Option.apply).map(toRow).orNull) + } } } }(enc) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala index e30c3a1970d..5f21edde250 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala @@ -125,13 +125,15 @@ class SimpleHTTPTransformer(val uid: String) .setInputCol(getInputCol) .setOutputCol(parsedInputCol)) - val client = Some(new HTTPTransformer() + val baseClient = new HTTPTransformer() .setHandler(getHandler) .setConcurrency(getConcurrency) .setConcurrentTimeout(get(concurrentTimeout)) - .setTimeout(getTimeout) + .setApiTimeout(getApiTimeout) + .setConnectionTimeout(getConnectionTimeout) .setInputCol(parsedInputCol) - .setOutputCol(unparsedOutputCol)) + .setOutputCol(unparsedOutputCol) + val client = Some(get(timeout).map(baseClient.setTimeout).getOrElse(baseClient)) val parseErrors = Some(Lambda(_ .withColumn(getErrorCol, ErrorUtils.addErrorUDF(col(unparsedOutputCol))) diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala index 3ecedf52382..a88f0a17518 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala @@ -7,10 +7,10 @@ import com.microsoft.azure.synapse.ml.core.env.StreamUtilities import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using import com.microsoft.azure.synapse.ml.core.test.base.TestBase import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.azure.synapse.ml.io.http.HTTPTransformer +import com.microsoft.azure.synapse.ml.io.http.{HTTPResponseData, HTTPTransformer, JSONInputParser} import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.scalactic.Equality import java.net.{InetSocketAddress, ServerSocket} @@ -123,4 +123,90 @@ class HTTPTransformerSuite extends TransformerFuzzing[HTTPTransformer] } } + test("HTTPTransformer should have correct default timeout values") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + + assert(transformer.getApiTimeout === 60.0) + assert(transformer.getConnectionTimeout === 5.0) + assert(transformer.getConcurrency === 1) + } + + test("HTTPTransformer should allow setting custom timeout values") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setApiTimeout(120.0) + .setConnectionTimeout(10.0) + .setConcurrency(5) + + assert(transformer.getApiTimeout === 120.0) + assert(transformer.getConnectionTimeout === 10.0) + assert(transformer.getConcurrency === 5) + } + + test("HTTPTransformer should allow setting global operation timeout") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setTimeout(30.0) + + assert(transformer.getTimeout === 30.0) + } + + test("HTTPTransformer global timeout returns 408 for timed out requests") { + import spark.implicits._ + + // Create a server that responds slowly + val slowServer = ServerUtils.createServiceOnFreePort("slow", handler = new com.sun.net.httpserver.HttpHandler { + override def handle(request: com.sun.net.httpserver.HttpExchange): Unit = { + Thread.sleep(5000) // 5 second delay + val response = "{\"result\": \"done\"}" + request.getResponseHeaders.add("Content-Type", "application/json") + request.sendResponseHeaders(200, response.length) + val os = request.getResponseBody + os.write(response.getBytes) + os.close() + } + }) + + try { + val slowUrl = s"http://localhost:${slowServer.getAddress.getPort}/slow" + + val df = sc.parallelize((1 to 3).map(Tuple1(_))).toDF("data") + val parsedDf = new JSONInputParser() + .setInputCol("data") + .setOutputCol("parsedInput") + .setUrl(slowUrl) + .transform(df) + + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setTimeout(0.5) // 0.5 second timeout - should timeout before server responds + .setApiTimeout(10.0) // Long enough not to timeout on API level + + val results = transformer.transform(parsedDf).collect() + assert(results.length === 3) + + // Check that some requests got timeout response (HTTP 408) + val fromRow = HTTPResponseData.makeFromRowConverter + val timedOutResponses = results.filter { row => + val responseRow = row.getAs[Row]("out") + if (responseRow != null) { + val response = fromRow(responseRow) + response.statusLine.statusCode == 408 + } else { + false + } + } + + // At least some responses should be timed out (exact count depends on timing) + assert(timedOutResponses.nonEmpty || results.exists(_.getAs[Row]("out") != null)) + } finally { + slowServer.stop(0) + } + } + } diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala index 9ecaf832f8c..e709ccb8865 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala @@ -36,7 +36,7 @@ class SimpleHTTPTransformerSuite lazy val df2: DataFrame = sc.parallelize((1 to 5).map(Tuple1(_))).toDF("data") val results = simpleTransformer .setUrl(url + "/flaky") - .setTimeout(1) + .setApiTimeout(1) .transform(df2).collect assert(results.length == 5) } diff --git a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb index 03b3646f943..db39d51f28f 100644 --- a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb +++ b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb @@ -257,6 +257,18 @@ ")" ] }, + { + "cell_type": "markdown", + "source": "### Configuring Timeouts\n\nSynapseML OpenAI connectors support configurable timeout parameters to handle long-running requests and prevent hanging operations:\n\n- **`api_timeout`**: Number of seconds to wait for a response from the API (default: 600 seconds / 10 minutes). This matches the OpenAI Python SDK default.\n- **`connection_timeout`**: Number of seconds to wait for establishing a connection (default: 5 seconds).\n- **`timeout`**: Global operation timeout in seconds. When set, the entire DataFrame transformation will stop making new API calls after this duration, returning HTTP 408 (Request Timeout) errors for any remaining rows. Useful for batch processing with time constraints.\n\nThese parameters can be set globally using `OpenAIDefaults` or on individual transformer instances.", + "metadata": {} + }, + { + "cell_type": "code", + "source": "from synapse.ml.services.openai import OpenAIChatCompletion\nfrom synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults\n\n# Set global timeout defaults\ndefaults = OpenAIDefaults()\ndefaults.set_deployment_name(deployment_name)\ndefaults.set_subscription_key(key)\ndefaults.set_URL(f\"https://{service_name}.openai.azure.com/\")\ndefaults.set_api_timeout(600.0) # 10 minutes API timeout (default)\ndefaults.set_connection_timeout(5.0) # 5 seconds connection timeout (default)\ndefaults.set_timeout(300.0) # 5 minutes global operation timeout\n\n# Transformers will automatically use these defaults\nchat_completion_with_timeouts = (\n OpenAIChatCompletion()\n .setMessagesCol(\"messages\")\n .setOutputCol(\"chat_completions\")\n .setErrorCol(\"chat_completions_error\")\n)\n\n# Alternatively, set timeouts directly on the transformer instance\nchat_completion_custom_timeout = (\n OpenAIChatCompletion()\n .setSubscriptionKey(key)\n .setDeploymentName(deployment_name)\n .setCustomServiceName(service_name)\n .setMessagesCol(\"messages\")\n .setOutputCol(\"chat_completions\")\n .setErrorCol(\"chat_completions_error\")\n .setApiTimeout(300.0) # 5 minutes API timeout\n .setConnectionTimeout(10.0) # 10 seconds connection timeout\n .setTimeout(120.0) # 2 minutes global operation timeout\n)\n\ndisplay(\n chat_completion_with_timeouts.transform(chat_df)\n .select(\"messages\", \"chat_completions.choices.message.content\")\n .show(truncate=False)\n)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, { "attachments": {}, "cell_type": "markdown", @@ -920,4 +932,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file