From 57825d3637526c1825ebbdaa0a06862a3ed1e028 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Mon, 15 Dec 2025 19:45:27 -0800 Subject: [PATCH 1/8] Add API timeout and Connection tmeout similar to OpenAI python SDK --- .../ml/services/openai/OpenAIDefaults.py | 26 +++++++++++++++++++ .../ml/services/CognitiveServiceBase.scala | 1 + .../synapse/ml/services/openai/OpenAI.scala | 4 ++- .../ml/services/openai/OpenAIDefaults.scala | 26 +++++++++++++++++++ .../ml/services/openai/OpenAIPrompt.scala | 5 +++- .../synapse/ml/io/http/HTTPClients.scala | 12 ++++++--- .../synapse/ml/io/http/HTTPTransformer.scala | 23 +++++++++++++--- .../ml/io/http/SimpleHTTPTransformer.scala | 1 + 8 files changed, 89 insertions(+), 9 deletions(-) diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index 28ff648ee90..e40a9f5a654 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -137,3 +137,29 @@ def get_api_type(self): def reset_api_type(self): self.defaults.resetApiType() + + def set_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError(f"Timeout must be greater than 0, got: {timeout_float}") + self.defaults.setTimeout(timeout_float) + + def get_timeout(self): + return getOption(self.defaults.getTimeout()) + + def reset_timeout(self): + self.defaults.resetTimeout() + + def set_connection_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError( + f"Connection timeout must be greater than 0, got: {timeout_float}" + ) + self.defaults.setConnectionTimeout(timeout_float) + + def get_connection_timeout(self): + return getOption(self.defaults.getConnectionTimeout()) + + def reset_connection_timeout(self): + self.defaults.resetConnectionTimeout() diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index cdf63544236..2c4f7a2d5e4 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -590,6 +590,7 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform .setConcurrency(getConcurrency) .setConcurrentTimeout(get(concurrentTimeout)) .setTimeout(getTimeout) + .setConnectionTimeout(getConnectionTimeout) .setErrorCol(getErrorCol), new DropColumns().setCol(dynamicParamColName) ) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index c50509ed5a9..f20f93e481d 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -108,6 +108,8 @@ case object OpenAITopPKey extends GlobalKey[Either[Double, String]] case object OpenAIVerbosityKey extends GlobalKey[Either[String, String]] case object OpenAIReasoningEffortKey extends GlobalKey[Either[String, String]] case object OpenAIApiTypeKey extends GlobalKey[String] +case object OpenAITimeoutKey extends GlobalKey[Double] +case object OpenAIConnectionTimeoutKey extends GlobalKey[Double] // scalastyle:off number.of.methods trait HasOpenAITextParams extends HasOpenAISharedParams { @@ -412,7 +414,7 @@ trait HasTextOutput { abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) with HasOpenAISharedParams with OpenAIFabricSetting { - setDefault(timeout -> 360.0) + setDefault(timeout -> 600.0) private def usingDefaultOpenAIEndpoint(): Boolean = { getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/" diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index 8d63032898a..d11113708fb 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -161,6 +161,32 @@ object OpenAIDefaults { GlobalParams.resetGlobalParam(OpenAIApiTypeKey) } + def setTimeout(v: Double): Unit = { + require(v > 0, s"Timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAITimeoutKey, v) + } + + def getTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAITimeoutKey) + } + + def resetTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAITimeoutKey) + } + + def setConnectionTimeout(v: Double): Unit = { + require(v > 0, s"Connection timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAIConnectionTimeoutKey, v) + } + + def getConnectionTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAIConnectionTimeoutKey) + } + + def resetConnectionTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAIConnectionTimeoutKey) + } + private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = { optEither match { case Some(Left(v)) => Some(v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index fc8fa15ff3a..555cad3943c 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -49,6 +49,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer logClass(FeatureNames.AiServices.OpenAI) + GlobalParams.registerParam(timeout, OpenAITimeoutKey) + GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey) + def this() = this(Identifiable.randomUID("OpenAIPrompt")) override def copy(extra: ParamMap): Transformer = defaultCopy(extra) @@ -176,7 +179,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer systemPrompt -> defaultSystemPrompt, apiType -> "chat_completions", columnTypes -> Map.empty, - timeout -> 360.0 + timeout -> 600.0 ) override def setCustomServiceName(v: String): this.type = { diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala index 175d238dc09..c2e47b4599d 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala @@ -32,10 +32,11 @@ private[ml] trait HTTPClient extends BaseClient override type RequestType = HTTPRequestData protected val requestTimeout: Int + protected val connectionTimeout: Int protected val requestConfig: RequestConfig = RequestConfig.custom() - .setConnectTimeout(requestTimeout) - .setConnectionRequestTimeout(requestTimeout) + .setConnectTimeout(connectionTimeout) + .setConnectionRequestTimeout(connectionTimeout) .setSocketTimeout(requestTimeout) .build() @@ -184,7 +185,8 @@ object HandlingUtils extends SparkLogging { class AsyncHTTPClient(val handler: HandlingUtils.HandlerFunc, override val concurrency: Int, override val timeout: Duration, - val requestTimeout: Int) + val requestTimeout: Int, + val connectionTimeout: Int) (override implicit val ec: ExecutionContext) extends AsyncClient(concurrency, timeout)(ec) with HTTPClient { override def handle(client: CloseableHttpClient, @@ -193,7 +195,9 @@ class AsyncHTTPClient(val handler: HandlingUtils.HandlerFunc, } } -class SingleThreadedHTTPClient(val handler: HandlingUtils.HandlerFunc, val requestTimeout: Int) +class SingleThreadedHTTPClient(val handler: HandlingUtils.HandlerFunc, + val requestTimeout: Int, + val connectionTimeout: Int) extends HTTPClient with SingleThreadedClient { override def handle(client: CloseableHttpClient, request: HTTPRequestData): HTTPResponseData = blocking { diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index 4fba4c03852..f61eda07299 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -60,6 +60,15 @@ trait ConcurrencyParams extends Wrappable { /** @group setParam */ def setTimeout(value: Double): this.type = set(timeout, value) + val connectionTimeout: Param[Double] = new DoubleParam( + this, "connectionTimeout", "number of seconds to wait for establishing a connection") + + /** @group getParam */ + def getConnectionTimeout: Double = $(connectionTimeout) + + /** @group setParam */ + def setConnectionTimeout(value: Double): this.type = set(connectionTimeout, value) + val concurrentTimeout: Param[Double] = new DoubleParam( this, "concurrentTimeout", "max number seconds to wait on futures if concurrency >= 1") @@ -73,7 +82,7 @@ trait ConcurrencyParams extends Wrappable { case Some(v) => setConcurrentTimeout(v) case None => clear(concurrentTimeout) } - setDefault(concurrency -> 1, timeout -> 60.0) + setDefault(concurrency -> 1, timeout -> 60.0, connectionTimeout -> 5.0) } case object URLKey extends GlobalKey[String] @@ -106,13 +115,21 @@ class HTTPTransformer(val uid: String) val clientHolder = SharedVariable { getConcurrency match { - case 1 => new SingleThreadedHTTPClient(getHandler, (getTimeout * 1000).toInt) + case 1 => new SingleThreadedHTTPClient( + getHandler, + (getTimeout * 1000).toInt, + (getConnectionTimeout * 1000).toInt) case n if n > 1 => val dur = get(concurrentTimeout) .map(ct => Duration.fromNanos((ct * math.pow(10, 9)).toLong)) //scalastyle:ignore magic.number .getOrElse(Duration.Inf) val ec = ExecutionContext.global - new AsyncHTTPClient(getHandler, n, dur, (getTimeout * 1000).toInt)(ec) + new AsyncHTTPClient( + getHandler, + n, + dur, + (getTimeout * 1000).toInt, + (getConnectionTimeout * 1000).toInt)(ec) } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala index e30c3a1970d..2cd6df0e5fe 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala @@ -130,6 +130,7 @@ class SimpleHTTPTransformer(val uid: String) .setConcurrency(getConcurrency) .setConcurrentTimeout(get(concurrentTimeout)) .setTimeout(getTimeout) + .setConnectionTimeout(getConnectionTimeout) .setInputCol(parsedInputCol) .setOutputCol(unparsedOutputCol)) From 1f0b28d57c811287056d88ef846eec00827aadbd Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 16 Dec 2025 01:57:43 -0800 Subject: [PATCH 2/8] Push timeout changes --- .../ml/services/openai/OpenAIDefaults.py | 26 ++++-- .../ml/services/CognitiveServiceBase.scala | 25 +++--- .../synapse/ml/services/openai/OpenAI.scala | 5 +- .../ml/services/openai/OpenAIDefaults.scala | 27 ++++-- .../ml/services/openai/OpenAIPrompt.scala | 5 +- .../services/openai/test_OpenAIDefaults.py | 40 +++++++++ .../services/openai/OpenAIDefaultsSuite.scala | 43 +++++++++ .../services/openai/OpenAIPromptSuite.scala | 76 ++++++++++++++++ .../azure/synapse/ml/io/http/Clients.scala | 7 +- .../synapse/ml/io/http/HTTPClients.scala | 12 ++- .../azure/synapse/ml/io/http/HTTPSchema.scala | 6 +- .../synapse/ml/io/http/HTTPTransformer.scala | 78 +++++++++++++--- .../ml/io/http/SimpleHTTPTransformer.scala | 7 +- .../ml/io/split1/HTTPTransformerSuite.scala | 90 ++++++++++++++++++- .../split1/SimpleHTTPTransformerSuite.scala | 2 +- docs/Explore Algorithms/OpenAI/OpenAI.ipynb | 14 ++- 16 files changed, 406 insertions(+), 57 deletions(-) diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index e40a9f5a654..45a2a605342 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -138,17 +138,17 @@ def get_api_type(self): def reset_api_type(self): self.defaults.resetApiType() - def set_timeout(self, timeout): + def set_api_timeout(self, timeout): timeout_float = float(timeout) if timeout_float <= 0: - raise ValueError(f"Timeout must be greater than 0, got: {timeout_float}") - self.defaults.setTimeout(timeout_float) + raise ValueError(f"API timeout must be greater than 0, got: {timeout_float}") + self.defaults.setApiTimeout(timeout_float) - def get_timeout(self): - return getOption(self.defaults.getTimeout()) + def get_api_timeout(self): + return getOption(self.defaults.getApiTimeout()) - def reset_timeout(self): - self.defaults.resetTimeout() + def reset_api_timeout(self): + self.defaults.resetApiTimeout() def set_connection_timeout(self, timeout): timeout_float = float(timeout) @@ -163,3 +163,15 @@ def get_connection_timeout(self): def reset_connection_timeout(self): self.defaults.resetConnectionTimeout() + + def set_timeout(self, timeout): + timeout_float = float(timeout) + if timeout_float <= 0: + raise ValueError(f"Timeout must be greater than 0, got: {timeout_float}") + self.defaults.setTimeout(timeout_float) + + def get_timeout(self): + return getOption(self.defaults.getTimeout()) + + def reset_timeout(self): + self.defaults.resetTimeout() diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala index 2c4f7a2d5e4..cf9721a78ca 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala @@ -579,19 +579,22 @@ abstract class CognitiveServicesBaseNoHandler(val uid: String) extends Transform case l => l } + val baseTransformer = new SimpleHTTPTransformer() + .setInputCol(dynamicParamColName) + .setOutputCol(getOutputCol) + .setInputParser(getInternalInputParser(schema)) + .setOutputParser(getInternalOutputParser(schema)) + .setHandler(handlingFunc _) + .setConcurrency(getConcurrency) + .setConcurrentTimeout(get(concurrentTimeout)) + .setApiTimeout(getApiTimeout) + .setConnectionTimeout(getConnectionTimeout) + .setErrorCol(getErrorCol) + val transformer = get(timeout).map(baseTransformer.setTimeout).getOrElse(baseTransformer) + val stages = Array( Lambda(_.withColumn(dynamicParamColName, struct(dynamicParamCols: _*))), - new SimpleHTTPTransformer() - .setInputCol(dynamicParamColName) - .setOutputCol(getOutputCol) - .setInputParser(getInternalInputParser(schema)) - .setOutputParser(getInternalOutputParser(schema)) - .setHandler(handlingFunc _) - .setConcurrency(getConcurrency) - .setConcurrentTimeout(get(concurrentTimeout)) - .setTimeout(getTimeout) - .setConnectionTimeout(getConnectionTimeout) - .setErrorCol(getErrorCol), + transformer, new DropColumns().setCol(dynamicParamColName) ) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala index f20f93e481d..58960ca9616 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala @@ -108,8 +108,9 @@ case object OpenAITopPKey extends GlobalKey[Either[Double, String]] case object OpenAIVerbosityKey extends GlobalKey[Either[String, String]] case object OpenAIReasoningEffortKey extends GlobalKey[Either[String, String]] case object OpenAIApiTypeKey extends GlobalKey[String] -case object OpenAITimeoutKey extends GlobalKey[Double] +case object OpenAIApiTimeoutKey extends GlobalKey[Double] case object OpenAIConnectionTimeoutKey extends GlobalKey[Double] +case object OpenAITimeoutKey extends GlobalKey[Double] // scalastyle:off number.of.methods trait HasOpenAITextParams extends HasOpenAISharedParams { @@ -414,7 +415,7 @@ trait HasTextOutput { abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String) with HasOpenAISharedParams with OpenAIFabricSetting { - setDefault(timeout -> 600.0) + setDefault(apiTimeout -> 600.0) private def usingDefaultOpenAIEndpoint(): Boolean = { getUrl == FabricClient.MLWorkloadEndpointML + "/cognitive/openai/" diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala index d11113708fb..99ff6b1aa1d 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaults.scala @@ -161,17 +161,17 @@ object OpenAIDefaults { GlobalParams.resetGlobalParam(OpenAIApiTypeKey) } - def setTimeout(v: Double): Unit = { - require(v > 0, s"Timeout must be greater than 0, got: $v") - GlobalParams.setGlobalParam(OpenAITimeoutKey, v) + def setApiTimeout(v: Double): Unit = { + require(v > 0, s"API timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAIApiTimeoutKey, v) } - def getTimeout: Option[Double] = { - GlobalParams.getGlobalParam(OpenAITimeoutKey) + def getApiTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAIApiTimeoutKey) } - def resetTimeout(): Unit = { - GlobalParams.resetGlobalParam(OpenAITimeoutKey) + def resetApiTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAIApiTimeoutKey) } def setConnectionTimeout(v: Double): Unit = { @@ -187,6 +187,19 @@ object OpenAIDefaults { GlobalParams.resetGlobalParam(OpenAIConnectionTimeoutKey) } + def setTimeout(v: Double): Unit = { + require(v > 0, s"Timeout must be greater than 0, got: $v") + GlobalParams.setGlobalParam(OpenAITimeoutKey, v) + } + + def getTimeout: Option[Double] = { + GlobalParams.getGlobalParam(OpenAITimeoutKey) + } + + def resetTimeout(): Unit = { + GlobalParams.resetGlobalParam(OpenAITimeoutKey) + } + private def extractLeft[T](optEither: Option[Either[T, String]]): Option[T] = { optEither match { case Some(Left(v)) => Some(v) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala index 555cad3943c..4658aa2cbd2 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala @@ -49,8 +49,9 @@ class OpenAIPrompt(override val uid: String) extends Transformer logClass(FeatureNames.AiServices.OpenAI) - GlobalParams.registerParam(timeout, OpenAITimeoutKey) + GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey) GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey) + GlobalParams.registerParam(timeout, OpenAITimeoutKey) def this() = this(Identifiable.randomUID("OpenAIPrompt")) @@ -179,7 +180,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer systemPrompt -> defaultSystemPrompt, apiType -> "chat_completions", columnTypes -> Map.empty, - timeout -> 600.0 + apiTimeout -> 600.0 ) override def setCustomServiceName(v: String): this.type = { diff --git a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py index df96c8e908d..1eb9fc14fc3 100644 --- a/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py +++ b/cognitive/src/test/python/synapsemltest/services/openai/test_OpenAIDefaults.py @@ -30,6 +30,9 @@ def test_setters_and_getters(self): defaults.set_api_version("2024-05-01-preview") defaults.set_model("grok-3-mini") defaults.set_embedding_deployment_name("text-embedding-ada-002") + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(120.0) self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") @@ -42,6 +45,9 @@ def test_setters_and_getters(self): self.assertEqual( defaults.get_embedding_deployment_name(), "text-embedding-ada-002" ) + self.assertEqual(defaults.get_api_timeout(), 600.0) + self.assertEqual(defaults.get_connection_timeout(), 5.0) + self.assertEqual(defaults.get_timeout(), 120.0) def test_resetters(self): defaults = OpenAIDefaults() @@ -55,6 +61,9 @@ def test_resetters(self): defaults.set_api_version("2024-05-01-preview") defaults.set_model("grok-3-mini") defaults.set_embedding_deployment_name("text-embedding-ada-002") + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(120.0) self.assertEqual(defaults.get_deployment_name(), "Bing Bong") self.assertEqual(defaults.get_subscription_key(), "SubKey") @@ -67,6 +76,9 @@ def test_resetters(self): self.assertEqual( defaults.get_embedding_deployment_name(), "text-embedding-ada-002" ) + self.assertEqual(defaults.get_api_timeout(), 600.0) + self.assertEqual(defaults.get_connection_timeout(), 5.0) + self.assertEqual(defaults.get_timeout(), 120.0) defaults.reset_deployment_name() defaults.reset_subscription_key() @@ -77,6 +89,9 @@ def test_resetters(self): defaults.reset_api_version() defaults.reset_model() defaults.reset_embedding_deployment_name() + defaults.reset_api_timeout() + defaults.reset_connection_timeout() + defaults.reset_timeout() self.assertEqual(defaults.get_deployment_name(), None) self.assertEqual(defaults.get_subscription_key(), None) @@ -87,6 +102,9 @@ def test_resetters(self): self.assertEqual(defaults.get_api_version(), None) self.assertEqual(defaults.get_model(), None) self.assertEqual(defaults.get_embedding_deployment_name(), None) + self.assertEqual(defaults.get_api_timeout(), None) + self.assertEqual(defaults.get_connection_timeout(), None) + self.assertEqual(defaults.get_timeout(), None) def test_two_defaults(self): defaults = OpenAIDefaults() @@ -168,6 +186,28 @@ def test_parameter_validation(self): with self.assertRaises(ValueError): defaults.set_top_p(1.1) + # Test valid timeout values + defaults.set_api_timeout(1.0) + defaults.set_api_timeout(600.0) + defaults.set_connection_timeout(1.0) + defaults.set_connection_timeout(5.0) + defaults.set_timeout(60.0) + defaults.set_timeout(120.0) + + # Test invalid timeout values (must be > 0) + with self.assertRaises(ValueError): + defaults.set_api_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_api_timeout(-1.0) + with self.assertRaises(ValueError): + defaults.set_connection_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_connection_timeout(-1.0) + with self.assertRaises(ValueError): + defaults.set_timeout(0.0) + with self.assertRaises(ValueError): + defaults.set_timeout(-1.0) + class TestResponseFormatJsonSchema(unittest.TestCase): def setUp(self): diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala index d226ebe64e3..55b7f5c466a 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIDefaultsSuite.scala @@ -88,6 +88,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setVerbosity("medium") OpenAIDefaults.setReasoningEffort("medium") OpenAIDefaults.setApiType("responses") + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(120.0) assert(OpenAIDefaults.getDeploymentName.contains(deploymentName)) assert(OpenAIDefaults.getSubscriptionKey.contains(openAIAPIKey)) @@ -101,6 +104,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { assert(OpenAIDefaults.getVerbosity.contains("medium")) assert(OpenAIDefaults.getReasoningEffort.contains("medium")) assert(OpenAIDefaults.getApiType.contains("responses")) + assert(OpenAIDefaults.getApiTimeout.contains(600.0)) + assert(OpenAIDefaults.getConnectionTimeout.contains(5.0)) + assert(OpenAIDefaults.getTimeout.contains(120.0)) } test("Test Resetters") { @@ -116,6 +122,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.setVerbosity("medium") OpenAIDefaults.setReasoningEffort("medium") OpenAIDefaults.setApiType("responses") + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(120.0) OpenAIDefaults.resetDeploymentName() OpenAIDefaults.resetSubscriptionKey() @@ -129,6 +138,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { OpenAIDefaults.resetVerbosity() OpenAIDefaults.resetReasoningEffort() OpenAIDefaults.resetApiType() + OpenAIDefaults.resetApiTimeout() + OpenAIDefaults.resetConnectionTimeout() + OpenAIDefaults.resetTimeout() assert(OpenAIDefaults.getDeploymentName.isEmpty) assert(OpenAIDefaults.getSubscriptionKey.isEmpty) @@ -142,6 +154,9 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { assert(OpenAIDefaults.getVerbosity.isEmpty) assert(OpenAIDefaults.getReasoningEffort.isEmpty) assert(OpenAIDefaults.getApiType.isEmpty) + assert(OpenAIDefaults.getApiTimeout.isEmpty) + assert(OpenAIDefaults.getConnectionTimeout.isEmpty) + assert(OpenAIDefaults.getTimeout.isEmpty) } test("Test Parameter Validation") { @@ -183,5 +198,33 @@ class OpenAIDefaultsSuite extends Flaky with OpenAIAPIKey { // Test reasoning effort values OpenAIDefaults.setReasoningEffort("low") OpenAIDefaults.setReasoningEffort("anything") + + // Test valid timeout values + OpenAIDefaults.setApiTimeout(1.0) + OpenAIDefaults.setApiTimeout(600.0) + OpenAIDefaults.setConnectionTimeout(1.0) + OpenAIDefaults.setConnectionTimeout(5.0) + OpenAIDefaults.setTimeout(60.0) + OpenAIDefaults.setTimeout(120.0) + + // Test invalid timeout values (must be > 0) + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setApiTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setApiTimeout(-1.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setConnectionTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setConnectionTimeout(-1.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setTimeout(0.0) + } + assertThrows[IllegalArgumentException] { + OpenAIDefaults.setTimeout(-1.0) + } } } diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index ac92b3e2b63..1e6cf664f78 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -406,6 +406,82 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK )) } + lazy val longInputDf: DataFrame = Seq( + ("apple", "fruits"), + (null, null), // scalastyle:ignore null + ("Flying on the weekends is a lot of fun! " * 1000, "travel"), + ("Flying is a lot of fun! " * 10000, "travel") + ).toDF("text", "category") + + test("Long Input Handling") { + val results = prompt + .setPromptTemplate("Summarize the following text in 10 words or less: {text}") + .setTimeout(120.0) + .transform(longInputDf) + .select("outParsed", prompt.getErrorCol) + .collect() + + // First 3 rows should have valid outputs + assert(Option(results(0).get(0)).isDefined) + assert(Option(results(1).get(0)).isDefined) + assert(Option(results(2).get(0)).isDefined) + + // Null input should return null output + assert(results(3).get(0) == null) + + // Long inputs should either succeed or have an error (token limit exceeded) + // Row 4: 1000 repetitions - may succeed or fail depending on model limits + // Row 5: 10000 repetitions - likely to exceed token limits + val row4HasOutput = Option(results(4).get(0)).isDefined + val row4HasError = Option(results(4).getAs[Row](1)).isDefined + assert(row4HasOutput || row4HasError, "Row 4 should have either output or error") + + val row5HasOutput = Option(results(5).get(0)).isDefined + val row5HasError = Option(results(5).getAs[Row](1)).isDefined + assert(row5HasOutput || row5HasError, "Row 5 should have either output or error") + } + + test("Timeout Configuration") { + // Test that timeout parameters are properly set and retrieved + val promptWithTimeout = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setApiTimeout(300.0) + .setConnectionTimeout(10.0) + .setTimeout(60.0) + + assert(promptWithTimeout.getApiTimeout == 300.0) + assert(promptWithTimeout.getConnectionTimeout == 10.0) + assert(promptWithTimeout.getTimeout == 60.0) + } + + test("Short Timeout Returns Timeout Error") { + val promptWithShortTimeout = new OpenAIPrompt() + .setSubscriptionKey(openAIAPIKey) + .setDeploymentName(deploymentName) + .setCustomServiceName(openAIServiceName) + .setOutputCol("outParsed") + .setPromptTemplate("List 5 {category}") + .setTimeout(0.001) // Very short timeout to force timeout error + + val results = promptWithShortTimeout + .transform(df) + .select("outParsed", promptWithShortTimeout.getErrorCol) + .collect() + + // All rows should have timeout errors due to very short timeout + results.foreach { row => + val errorRow = row.getAs[Row](1) + assert(errorRow != null, "Should have error due to timeout") + val errorResponse = errorRow.getAs[String]("response") + assert(errorResponse.contains("exceeded the time limit") || + errorResponse.contains("timeout"), + s"Error should mention timeout, got: $errorResponse") + } + } + override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = { super.assertDFEq(df1.drop("out", "outParsed"), df2.drop("out", "outParsed"))(eq) } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala index c192a0fa81f..f4b6978b0eb 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/Clients.scala @@ -21,8 +21,11 @@ private[ml] trait BaseClient { def this(response: Option[ResponseType]) = this(response, None) } - case class RequestWithContext(request: Option[RequestType], context: Context) { - def this(request: Option[RequestType]) = this(request, None) + case class RequestWithContext(request: Option[RequestType], + context: Context, + precomputedResponse: Option[ResponseType] = None) { + def this(request: Option[RequestType]) = this(request, None, None) + def this(request: Option[RequestType], context: Context) = this(request, context, None) } protected lazy val logger: Logger = LogManager.getLogger("BaseClient") diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala index c2e47b4599d..79851048ea1 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPClients.scala @@ -56,9 +56,15 @@ private[ml] trait HTTPClient extends BaseClient } protected def sendRequestWithContext(request: RequestWithContext): ResponseWithContext = { - request.request.map(req => - ResponseWithContext(Some(handle(internalClient, req)), request.context) - ).getOrElse(ResponseWithContext(None, request.context)) + // If there's a precomputed response (e.g., timeout), return it without making the HTTP call + request.precomputedResponse match { + case Some(response) => + ResponseWithContext(Some(response), request.context) + case None => + request.request.map(req => + ResponseWithContext(Some(handle(internalClient, req)), request.context) + ).getOrElse(ResponseWithContext(None, request.context)) + } } } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala index 46ffbfc237c..ce3a540687e 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPSchema.scala @@ -290,7 +290,7 @@ object HTTPSchema { HTTPResponseData( Array(), Some(stringToEntity(x)), - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } @@ -304,7 +304,7 @@ object HTTPSchema { HTTPResponseData( Array(), None, - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } @@ -318,7 +318,7 @@ object HTTPSchema { HTTPResponseData( Array(), Some(binaryToEntity(x)), - StatusLineData(null, code, reason), + StatusLineData(ProtocolVersionData("HTTP", 1, 1), code, reason), "en") } diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index f61eda07299..f1ade193328 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -51,14 +51,14 @@ trait ConcurrencyParams extends Wrappable { /** @group setParam */ def setConcurrency(value: Int): this.type = set(concurrency, value) - val timeout: Param[Double] = new DoubleParam( - this, "timeout", "number of seconds to wait before closing the connection") + val apiTimeout: Param[Double] = new DoubleParam( + this, "apiTimeout", "number of seconds to wait for a response from the API") /** @group getParam */ - def getTimeout: Double = $(timeout) + def getApiTimeout: Double = $(apiTimeout) /** @group setParam */ - def setTimeout(value: Double): this.type = set(timeout, value) + def setApiTimeout(value: Double): this.type = set(apiTimeout, value) val connectionTimeout: Param[Double] = new DoubleParam( this, "connectionTimeout", "number of seconds to wait for establishing a connection") @@ -69,6 +69,15 @@ trait ConcurrencyParams extends Wrappable { /** @group setParam */ def setConnectionTimeout(value: Double): this.type = set(connectionTimeout, value) + val timeout: Param[Double] = new DoubleParam( + this, "timeout", "number of seconds to wait for the entire operation to complete") + + /** @group getParam */ + def getTimeout: Double = $(timeout) + + /** @group setParam */ + def setTimeout(value: Double): this.type = set(timeout, value) + val concurrentTimeout: Param[Double] = new DoubleParam( this, "concurrentTimeout", "max number seconds to wait on futures if concurrency >= 1") @@ -82,7 +91,7 @@ trait ConcurrencyParams extends Wrappable { case Some(v) => setConcurrentTimeout(v) case None => clear(concurrentTimeout) } - setDefault(concurrency -> 1, timeout -> 60.0, connectionTimeout -> 5.0) + setDefault(concurrency -> 1, apiTimeout -> 60.0, connectionTimeout -> 5.0) } case object URLKey extends GlobalKey[String] @@ -117,7 +126,7 @@ class HTTPTransformer(val uid: String) getConcurrency match { case 1 => new SingleThreadedHTTPClient( getHandler, - (getTimeout * 1000).toInt, + (getApiTimeout * 1000).toInt, (getConnectionTimeout * 1000).toInt) case n if n > 1 => val dur = get(concurrentTimeout) @@ -128,11 +137,19 @@ class HTTPTransformer(val uid: String) getHandler, n, dur, - (getTimeout * 1000).toInt, + (getApiTimeout * 1000).toInt, (getConnectionTimeout * 1000).toInt)(ec) } } + private def createTimeoutResponse(timeoutSeconds: Double): HTTPResponseData = { + HTTPSchema.stringToResponse( + f"The operation exceeded the time limit of $timeoutSeconds%.1f seconds. " + + "Fix: increase value of timeout or reset it for no limit.", + 408, //scalastyle:ignore magic.number HTTP_REQUEST_TIMEOUT + "Request Timeout") + } + /** @param dataset - The input dataset, to be transformed * @return The DataFrame that results from column selection */ @@ -143,16 +160,51 @@ class HTTPTransformer(val uid: String) val colIndex = df.schema.fieldNames.indexOf(getInputCol) val fromRow = HTTPRequestData.makeFromRowConverter val toRow = HTTPResponseData.makeToRowConverter + + val timeoutMs = get(timeout).map(t => (t * 1000).toLong) + val startTime = timeoutMs.map(_ => System.currentTimeMillis()) + val startTimeBroadcast = startTime.map(df.sparkSession.sparkContext.broadcast(_)) + + val timeoutSec = get(timeout) + df.mapPartitions { it => if (!it.hasNext) { Iterator() } else { - val c = clientHolder.get - val responsesWithContext = c.sendRequestsWithContext(it.map { row => - c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row)) - }) - responsesWithContext.map { rwc => - Row.fromSeq(rwc.context.get.asInstanceOf[Row].toSeq :+ rwc.response.flatMap(Option.apply).map(toRow).orNull) + // Early timeout check before creating HTTP client to avoid network errors + val isAlreadyTimedOut = (timeoutMs, startTimeBroadcast) match { + case (Some(tm), Some(startBroadcast)) => + System.currentTimeMillis() - startBroadcast.value > tm + case _ => false + } + + if (isAlreadyTimedOut) { + // Return timeout responses for all rows without creating HTTP client + it.map { row => + Row.fromSeq(row.toSeq :+ toRow(createTimeoutResponse(timeoutSec.get))) + } + } else { + val c = clientHolder.get + + val responsesWithContext = c.sendRequestsWithContext(it.map { row => + // Check if timeout has been exceeded + val isTimedOut = (timeoutMs, startTimeBroadcast) match { + case (Some(tm), Some(startBroadcast)) => + System.currentTimeMillis() - startBroadcast.value > tm + case _ => false + } + + if (isTimedOut) { + // Return a timeout response without making the API call + c.RequestWithContext(None, Some(row), Some(createTimeoutResponse(timeoutSec.get))) + } else { + c.RequestWithContext(Option(row.getStruct(colIndex)).map(fromRow), Some(row)) + } + }) + responsesWithContext.map { rwc => + Row.fromSeq(rwc.context.get.asInstanceOf[Row].toSeq :+ + rwc.response.flatMap(Option.apply).map(toRow).orNull) + } } } }(enc) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala index 2cd6df0e5fe..5f21edde250 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/SimpleHTTPTransformer.scala @@ -125,14 +125,15 @@ class SimpleHTTPTransformer(val uid: String) .setInputCol(getInputCol) .setOutputCol(parsedInputCol)) - val client = Some(new HTTPTransformer() + val baseClient = new HTTPTransformer() .setHandler(getHandler) .setConcurrency(getConcurrency) .setConcurrentTimeout(get(concurrentTimeout)) - .setTimeout(getTimeout) + .setApiTimeout(getApiTimeout) .setConnectionTimeout(getConnectionTimeout) .setInputCol(parsedInputCol) - .setOutputCol(unparsedOutputCol)) + .setOutputCol(unparsedOutputCol) + val client = Some(get(timeout).map(baseClient.setTimeout).getOrElse(baseClient)) val parseErrors = Some(Lambda(_ .withColumn(getErrorCol, ErrorUtils.addErrorUDF(col(unparsedOutputCol))) diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala index 3ecedf52382..a88f0a17518 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/HTTPTransformerSuite.scala @@ -7,10 +7,10 @@ import com.microsoft.azure.synapse.ml.core.env.StreamUtilities import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using import com.microsoft.azure.synapse.ml.core.test.base.TestBase import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, TransformerFuzzing} -import com.microsoft.azure.synapse.ml.io.http.HTTPTransformer +import com.microsoft.azure.synapse.ml.io.http.{HTTPResponseData, HTTPTransformer, JSONInputParser} import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.apache.spark.ml.util.MLReadable -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.scalactic.Equality import java.net.{InetSocketAddress, ServerSocket} @@ -123,4 +123,90 @@ class HTTPTransformerSuite extends TransformerFuzzing[HTTPTransformer] } } + test("HTTPTransformer should have correct default timeout values") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + + assert(transformer.getApiTimeout === 60.0) + assert(transformer.getConnectionTimeout === 5.0) + assert(transformer.getConcurrency === 1) + } + + test("HTTPTransformer should allow setting custom timeout values") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setApiTimeout(120.0) + .setConnectionTimeout(10.0) + .setConcurrency(5) + + assert(transformer.getApiTimeout === 120.0) + assert(transformer.getConnectionTimeout === 10.0) + assert(transformer.getConcurrency === 5) + } + + test("HTTPTransformer should allow setting global operation timeout") { + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setTimeout(30.0) + + assert(transformer.getTimeout === 30.0) + } + + test("HTTPTransformer global timeout returns 408 for timed out requests") { + import spark.implicits._ + + // Create a server that responds slowly + val slowServer = ServerUtils.createServiceOnFreePort("slow", handler = new com.sun.net.httpserver.HttpHandler { + override def handle(request: com.sun.net.httpserver.HttpExchange): Unit = { + Thread.sleep(5000) // 5 second delay + val response = "{\"result\": \"done\"}" + request.getResponseHeaders.add("Content-Type", "application/json") + request.sendResponseHeaders(200, response.length) + val os = request.getResponseBody + os.write(response.getBytes) + os.close() + } + }) + + try { + val slowUrl = s"http://localhost:${slowServer.getAddress.getPort}/slow" + + val df = sc.parallelize((1 to 3).map(Tuple1(_))).toDF("data") + val parsedDf = new JSONInputParser() + .setInputCol("data") + .setOutputCol("parsedInput") + .setUrl(slowUrl) + .transform(df) + + val transformer = new HTTPTransformer() + .setInputCol("parsedInput") + .setOutputCol("out") + .setTimeout(0.5) // 0.5 second timeout - should timeout before server responds + .setApiTimeout(10.0) // Long enough not to timeout on API level + + val results = transformer.transform(parsedDf).collect() + assert(results.length === 3) + + // Check that some requests got timeout response (HTTP 408) + val fromRow = HTTPResponseData.makeFromRowConverter + val timedOutResponses = results.filter { row => + val responseRow = row.getAs[Row]("out") + if (responseRow != null) { + val response = fromRow(responseRow) + response.statusLine.statusCode == 408 + } else { + false + } + } + + // At least some responses should be timed out (exact count depends on timing) + assert(timedOutResponses.nonEmpty || results.exists(_.getAs[Row]("out") != null)) + } finally { + slowServer.stop(0) + } + } + } diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala index 9ecaf832f8c..e709ccb8865 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/io/split1/SimpleHTTPTransformerSuite.scala @@ -36,7 +36,7 @@ class SimpleHTTPTransformerSuite lazy val df2: DataFrame = sc.parallelize((1 to 5).map(Tuple1(_))).toDF("data") val results = simpleTransformer .setUrl(url + "/flaky") - .setTimeout(1) + .setApiTimeout(1) .transform(df2).collect assert(results.length == 5) } diff --git a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb index 03b3646f943..db39d51f28f 100644 --- a/docs/Explore Algorithms/OpenAI/OpenAI.ipynb +++ b/docs/Explore Algorithms/OpenAI/OpenAI.ipynb @@ -257,6 +257,18 @@ ")" ] }, + { + "cell_type": "markdown", + "source": "### Configuring Timeouts\n\nSynapseML OpenAI connectors support configurable timeout parameters to handle long-running requests and prevent hanging operations:\n\n- **`api_timeout`**: Number of seconds to wait for a response from the API (default: 600 seconds / 10 minutes). This matches the OpenAI Python SDK default.\n- **`connection_timeout`**: Number of seconds to wait for establishing a connection (default: 5 seconds).\n- **`timeout`**: Global operation timeout in seconds. When set, the entire DataFrame transformation will stop making new API calls after this duration, returning HTTP 408 (Request Timeout) errors for any remaining rows. Useful for batch processing with time constraints.\n\nThese parameters can be set globally using `OpenAIDefaults` or on individual transformer instances.", + "metadata": {} + }, + { + "cell_type": "code", + "source": "from synapse.ml.services.openai import OpenAIChatCompletion\nfrom synapse.ml.services.openai.OpenAIDefaults import OpenAIDefaults\n\n# Set global timeout defaults\ndefaults = OpenAIDefaults()\ndefaults.set_deployment_name(deployment_name)\ndefaults.set_subscription_key(key)\ndefaults.set_URL(f\"https://{service_name}.openai.azure.com/\")\ndefaults.set_api_timeout(600.0) # 10 minutes API timeout (default)\ndefaults.set_connection_timeout(5.0) # 5 seconds connection timeout (default)\ndefaults.set_timeout(300.0) # 5 minutes global operation timeout\n\n# Transformers will automatically use these defaults\nchat_completion_with_timeouts = (\n OpenAIChatCompletion()\n .setMessagesCol(\"messages\")\n .setOutputCol(\"chat_completions\")\n .setErrorCol(\"chat_completions_error\")\n)\n\n# Alternatively, set timeouts directly on the transformer instance\nchat_completion_custom_timeout = (\n OpenAIChatCompletion()\n .setSubscriptionKey(key)\n .setDeploymentName(deployment_name)\n .setCustomServiceName(service_name)\n .setMessagesCol(\"messages\")\n .setOutputCol(\"chat_completions\")\n .setErrorCol(\"chat_completions_error\")\n .setApiTimeout(300.0) # 5 minutes API timeout\n .setConnectionTimeout(10.0) # 10 seconds connection timeout\n .setTimeout(120.0) # 2 minutes global operation timeout\n)\n\ndisplay(\n chat_completion_with_timeouts.transform(chat_df)\n .select(\"messages\", \"chat_completions.choices.message.content\")\n .show(truncate=False)\n)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, { "attachments": {}, "cell_type": "markdown", @@ -920,4 +932,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file From 817851cb72dfa633a0c5710bfa08c514bed60e70 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 16 Dec 2025 02:10:15 -0800 Subject: [PATCH 3/8] Fix black python formatting --- .../main/python/synapse/ml/services/openai/OpenAIDefaults.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py index 45a2a605342..69ffa656a57 100644 --- a/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py +++ b/cognitive/src/main/python/synapse/ml/services/openai/OpenAIDefaults.py @@ -141,7 +141,9 @@ def reset_api_type(self): def set_api_timeout(self, timeout): timeout_float = float(timeout) if timeout_float <= 0: - raise ValueError(f"API timeout must be greater than 0, got: {timeout_float}") + raise ValueError( + f"API timeout must be greater than 0, got: {timeout_float}" + ) self.defaults.setApiTimeout(timeout_float) def get_api_timeout(self): From 01c72384edef513b50e980dfeccce85f95f926cf Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Tue, 16 Dec 2025 02:58:06 -0800 Subject: [PATCH 4/8] Update to Long input test --- .../services/openai/OpenAIPromptSuite.scala | 31 +++++++++---------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala index 1e6cf664f78..40f0a8915a4 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPromptSuite.scala @@ -421,24 +421,21 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK .select("outParsed", prompt.getErrorCol) .collect() - // First 3 rows should have valid outputs + // Row 0: "apple" - normal input, should have valid output assert(Option(results(0).get(0)).isDefined) - assert(Option(results(1).get(0)).isDefined) - assert(Option(results(2).get(0)).isDefined) - - // Null input should return null output - assert(results(3).get(0) == null) - - // Long inputs should either succeed or have an error (token limit exceeded) - // Row 4: 1000 repetitions - may succeed or fail depending on model limits - // Row 5: 10000 repetitions - likely to exceed token limits - val row4HasOutput = Option(results(4).get(0)).isDefined - val row4HasError = Option(results(4).getAs[Row](1)).isDefined - assert(row4HasOutput || row4HasError, "Row 4 should have either output or error") - - val row5HasOutput = Option(results(5).get(0)).isDefined - val row5HasError = Option(results(5).getAs[Row](1)).isDefined - assert(row5HasOutput || row5HasError, "Row 5 should have either output or error") + + // Row 1: null input should return null output + assert(results(1).get(0) == null) + + // Row 2: 1000 repetitions - may succeed or fail depending on model limits + val row2HasOutput = Option(results(2).get(0)).isDefined + val row2HasError = Option(results(2).getAs[Row](1)).isDefined + assert(row2HasOutput || row2HasError, "Row 2 should have either output or error") + + // Row 3: 10000 repetitions - possible to exceed token limits + val row3HasOutput = Option(results(3).get(0)).isDefined + val row3HasError = Option(results(3).getAs[Row](1)).isDefined + assert(row3HasOutput || row3HasError, "Row 3 should have either output or error") } test("Timeout Configuration") { From 48617a962918a2f34e4f777b5de5353c01550eb0 Mon Sep 17 00:00:00 2001 From: Rana Singh Date: Wed, 17 Dec 2025 20:56:27 -0800 Subject: [PATCH 5/8] Update core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index f1ade193328..72acfb848f7 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -70,7 +70,9 @@ trait ConcurrencyParams extends Wrappable { def setConnectionTimeout(value: Double): this.type = set(connectionTimeout, value) val timeout: Param[Double] = new DoubleParam( - this, "timeout", "number of seconds to wait for the entire operation to complete") + this, "timeout", + "number of seconds for the entire DataFrame transformation to complete, measured from the start of the transform operation; " + + "rows processed after this timeout will receive HTTP 408 responses without making API calls") /** @group getParam */ def getTimeout: Double = $(timeout) From 27ab5e1fffbe68778d5de9f99394461aba9a6ee3 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Wed, 17 Dec 2025 21:34:53 -0800 Subject: [PATCH 6/8] Fix scala style issue --- .../azure/synapse/ml/io/http/HTTPTransformer.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala index 72acfb848f7..75108a712a9 100644 --- a/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala +++ b/core/src/main/scala/com/microsoft/azure/synapse/ml/io/http/HTTPTransformer.scala @@ -71,8 +71,10 @@ trait ConcurrencyParams extends Wrappable { val timeout: Param[Double] = new DoubleParam( this, "timeout", - "number of seconds for the entire DataFrame transformation to complete, measured from the start of the transform operation; " + - "rows processed after this timeout will receive HTTP 408 responses without making API calls") + "number of seconds for the entire DataFrame transformation to complete, " + + "measured from the start of the transform operation; " + + "rows processed after this timeout will receive HTTP 408 responses without making API calls" + ) /** @group getParam */ def getTimeout: Double = $(timeout) From 25c0453595bfeaf77347588cb05bc13a79366722 Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Thu, 18 Dec 2025 02:14:07 -0800 Subject: [PATCH 7/8] Add timeouts to OpenAIEmbedding --- .../azure/synapse/ml/services/openai/OpenAIEmbedding.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala index e980507066d..cfe1530af79 100644 --- a/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala +++ b/cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala @@ -28,6 +28,10 @@ class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid) with HasReturnUsage { logClass(FeatureNames.AiServices.OpenAI) + GlobalParams.registerParam(apiTimeout, OpenAIApiTimeoutKey) + GlobalParams.registerParam(connectionTimeout, OpenAIConnectionTimeoutKey) + GlobalParams.registerParam(timeout, OpenAITimeoutKey) + def this() = this(Identifiable.randomUID("OpenAIEmbedding")) def urlPath: String = "" From a05b321764e3f924877a4046b4ed65394cf46d4c Mon Sep 17 00:00:00 2001 From: Ranadeep Singh Date: Thu, 18 Dec 2025 02:25:36 -0800 Subject: [PATCH 8/8] Add tests to OpenAIEmbeddignSuite and renamed it --- ...Suite.scala => OpenAIEmbeddingSuite.scala} | 82 ++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) rename cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/{OpenAIEmbeddingsSuite.scala => OpenAIEmbeddingSuite.scala} (71%) diff --git a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala similarity index 71% rename from cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala rename to cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala index beeb11afde7..552265121d3 100644 --- a/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingsSuite.scala +++ b/cognitive/src/test/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbeddingSuite.scala @@ -11,7 +11,7 @@ import org.apache.spark.ml.linalg.{Vector, SQLDataTypes} import org.apache.spark.sql.types.StructType import org.scalactic.Equality -class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with OpenAIAPIKey with Flaky { +class OpenAIEmbeddingSuite extends TransformerFuzzing[OpenAIEmbedding] with OpenAIAPIKey with Flaky { import spark.implicits._ @@ -203,6 +203,86 @@ class OpenAIEmbeddingsSuite extends TransformerFuzzing[OpenAIEmbedding] with Ope assert(results(2).getAs[Row]("null_test_usage") != null) } + test("Timeout Configuration") { + // Test that timeout parameters are properly set and retrieved + val embeddingWithTimeout = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + .setApiTimeout(300.0) + .setConnectionTimeout(10.0) + .setTimeout(60.0) + + assert(embeddingWithTimeout.getApiTimeout == 300.0) + assert(embeddingWithTimeout.getConnectionTimeout == 10.0) + assert(embeddingWithTimeout.getTimeout == 60.0) + } + + test("Short Timeout Returns Timeout Error") { + val embeddingWithShortTimeout = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + .setTimeout(0.001) // Very short timeout to force timeout error + + val errorCol = embeddingWithShortTimeout.getErrorCol + val results = embeddingWithShortTimeout + .transform(df) + .select("out", errorCol) + .collect() + + // All rows should have timeout errors due to very short timeout + results.foreach { row => + val errorRow = row.getAs[Row](1) + assert(errorRow != null, "Should have error due to timeout") + val errorResponse = errorRow.getAs[String]("response") + assert(errorResponse.contains("exceeded the time limit") || + errorResponse.contains("timeout"), + s"Error should mention timeout, got: $errorResponse") + } + } + + test("Embedding uses global timeout defaults from OpenAIDefaults") { + val originalApiTimeout = OpenAIDefaults.getApiTimeout + val originalConnectionTimeout = OpenAIDefaults.getConnectionTimeout + val originalTimeout = OpenAIDefaults.getTimeout + + try { + OpenAIDefaults.setApiTimeout(350.0) + OpenAIDefaults.setConnectionTimeout(15.0) + OpenAIDefaults.setTimeout(120.0) + + val e = new OpenAIEmbedding() + .setSubscriptionKey(openAIAPIKey) + .setCustomServiceName(openAIServiceName) + .setDeploymentName("text-embedding-ada-002") + .setTextCol("text") + .setOutputCol("out") + + // The embedding should work with global defaults applied + val results = e.transform(df.limit(1)).collect() + assert(results.length == 1) + assert(results(0).getAs[Vector]("out").size > 0) + } finally { + // Reset global defaults to avoid cross-test contamination + originalApiTimeout match { + case Some(v) => OpenAIDefaults.setApiTimeout(v) + case None => OpenAIDefaults.resetApiTimeout() + } + originalConnectionTimeout match { + case Some(v) => OpenAIDefaults.setConnectionTimeout(v) + case None => OpenAIDefaults.resetConnectionTimeout() + } + originalTimeout match { + case Some(v) => OpenAIDefaults.setTimeout(v) + case None => OpenAIDefaults.resetTimeout() + } + } + } override def testObjects(): Seq[TestObject[OpenAIEmbedding]] = Seq(new TestObject(embedding, df))