diff --git a/dbtsl/__about__.py b/dbtsl/__about__.py index f5a5cb4..bec33f0 100644 --- a/dbtsl/__about__.py +++ b/dbtsl/__about__.py @@ -1 +1 @@ -VERSION = "0.13.1" +VERSION = "0.13.2" diff --git a/dbtsl/api/graphql/client/asyncio.py b/dbtsl/api/graphql/client/asyncio.py index 645729e..4c2a3d2 100644 --- a/dbtsl/api/graphql/client/asyncio.py +++ b/dbtsl/api/graphql/client/asyncio.py @@ -142,7 +142,7 @@ async def _poll_until_complete( elapsed_s = time.time() - start_s if elapsed_s > total_timeout_s: - raise RetryTimeoutError(timeout_s=total_timeout_s) + raise RetryTimeoutError(timeout_s=total_timeout_s, status=qr.status.value) await asyncio.sleep(sleep_ms / 1000) diff --git a/dbtsl/api/graphql/client/sync.py b/dbtsl/api/graphql/client/sync.py index 1e996fa..32725b5 100644 --- a/dbtsl/api/graphql/client/sync.py +++ b/dbtsl/api/graphql/client/sync.py @@ -130,7 +130,7 @@ def _poll_until_complete( elapsed_s = time.time() - start_s if elapsed_s > total_timeout_s: - raise RetryTimeoutError(timeout_s=total_timeout_s) + raise RetryTimeoutError(timeout_s=total_timeout_s, status=qr.status.value) time.sleep(sleep_ms / 1000) diff --git a/dbtsl/error.py b/dbtsl/error.py index 1fef7d8..28395c8 100644 --- a/dbtsl/error.py +++ b/dbtsl/error.py @@ -43,6 +43,22 @@ class ExecuteTimeoutError(TimeoutError): class RetryTimeoutError(TimeoutError): """Raise whenever a timeout occurred while retrying an operation against the servers.""" + def __init__(self, *, timeout_s: float, status: Optional[str] = None, **_kwargs: object) -> None: + """Initialize the retry timeout error. + + Args: + timeout_s: The maximum time limit that got exceeded, in seconds + status: The last known query status before the timeout occurred + **_kwargs: any other exception kwargs + """ + super().__init__(timeout_s=timeout_s) + self.status = status + + def __str__(self) -> str: # noqa: D105 + if self.status is not None: + return f"{self.__class__.__name__}(timeout_s={self.timeout_s}, status={self.status})" + return f"{self.__class__.__name__}(timeout_s={self.timeout_s})" + class QueryFailedError(SemanticLayerError): """Raise whenever a query has failed.""" diff --git a/tests/api/graphql/test_client.py b/tests/api/graphql/test_client.py index 6ef7b45..e33088b 100644 --- a/tests/api/graphql/test_client.py +++ b/tests/api/graphql/test_client.py @@ -10,6 +10,7 @@ from dbtsl.api.graphql.client.asyncio import AsyncGraphQLClient from dbtsl.api.graphql.client.sync import SyncGraphQLClient from dbtsl.api.graphql.protocol import GetQueryResultVariables, GraphQLProtocol, ProtocolOperation +from dbtsl.error import RetryTimeoutError from dbtsl.models.query import QueryId, QueryResult, QueryStatus # The following 2 tests are copies of each other since testing the same sync/async functionality is @@ -145,3 +146,63 @@ def run_behavior(op: ProtocolOperation[Any, Any], raw_variables: GetQueryResultV ) assert result_table.equals(table, check_metadata=True) + + +# avoid raising mock warning related to mocking a context manager +@pytest.mark.filterwarnings("ignore::pytest_mock.PytestMockWarning") +def test_sync_poll_timeout_includes_status(mocker: MockerFixture) -> None: + """Test that RetryTimeoutError includes the last known query status.""" + client = SyncGraphQLClient(server_host="test", environment_id=0, auth_token="test", timeout=0.001, lazy=False) + + compiled_result = QueryResult( + query_id=QueryId("test-query-id"), + status=QueryStatus.COMPILED, + sql=None, + error=None, + total_pages=None, + arrow_result=None, + ) + + run_mock = MagicMock(return_value=compiled_result) + mocker.patch.object(client, "_run", new=run_mock) + + mocker.patch.object(client, "create_query", return_value=QueryId("test-query-id")) + + gql_mock = mocker.patch.object(client, "_gql") + mocker.patch.object(gql_mock, "__aenter__") + mocker.patch("dbtsl.api.graphql.client.sync.isinstance", return_value=True) + + with client.session(): + with pytest.raises(RetryTimeoutError) as exc_info: + client.query(metrics=["m1"]) + + assert exc_info.value.status == "COMPILED" + + +async def test_async_poll_timeout_includes_status(mocker: MockerFixture) -> None: + """Test that RetryTimeoutError includes the last known query status (async).""" + client = AsyncGraphQLClient(server_host="test", environment_id=0, auth_token="test", timeout=0.001, lazy=False) + + compiled_result = QueryResult( + query_id=QueryId("test-query-id"), + status=QueryStatus.COMPILED, + sql=None, + error=None, + total_pages=None, + arrow_result=None, + ) + + run_mock = AsyncMock(return_value=compiled_result) + mocker.patch.object(client, "_run", new=run_mock) + + mocker.patch.object(client, "create_query", return_value=QueryId("test-query-id"), new_callable=AsyncMock) + + gql_mock = mocker.patch.object(client, "_gql") + mocker.patch.object(gql_mock, "__aenter__", new_callable=AsyncMock) + mocker.patch("dbtsl.api.graphql.client.asyncio.isinstance", return_value=True) + + async with client.session(): + with pytest.raises(RetryTimeoutError) as exc_info: + await client.query(metrics=["m1"]) + + assert exc_info.value.status == "COMPILED" diff --git a/tests/test_error.py b/tests/test_error.py index 9f9730c..a23c3ec 100644 --- a/tests/test_error.py +++ b/tests/test_error.py @@ -1,4 +1,4 @@ -from dbtsl.error import SemanticLayerError, TimeoutError +from dbtsl.error import RetryTimeoutError, SemanticLayerError, TimeoutError def test_error_str_calls_repr() -> None: @@ -15,3 +15,17 @@ def test_error_repr_with_args() -> None: def test_timeout_error_str() -> None: assert str(TimeoutError(timeout_s=1000)) == "TimeoutError(timeout_s=1000)" + + +def test_retry_timeout_error_without_status() -> None: + err = RetryTimeoutError(timeout_s=60) + assert err.timeout_s == 60 + assert err.status is None + assert str(err) == "RetryTimeoutError(timeout_s=60)" + + +def test_retry_timeout_error_with_status() -> None: + err = RetryTimeoutError(timeout_s=30, status="COMPILED") + assert err.timeout_s == 30 + assert err.status == "COMPILED" + assert str(err) == "RetryTimeoutError(timeout_s=30, status=COMPILED)"