diff --git a/doc/changelog.d/325.fixed.md b/doc/changelog.d/325.fixed.md new file mode 100644 index 00000000..039c4e89 --- /dev/null +++ b/doc/changelog.d/325.fixed.md @@ -0,0 +1 @@ +Fix/token timeout 2 diff --git a/src/ansys/conceptev/core/app.py b/src/ansys/conceptev/core/app.py index ebe83d63..5b3aa030 100644 --- a/src/ansys/conceptev/core/app.py +++ b/src/ansys/conceptev/core/app.py @@ -136,7 +136,7 @@ def get_http_client( client.send = retry( retry=retry_if_result(is_gateway_error), wait=wait_random_exponential(multiplier=1, max=60), - stop=stop_after_delay(10), + stop=stop_after_delay(120), )(client.send) return client diff --git a/src/ansys/conceptev/core/auth.py b/src/ansys/conceptev/core/auth.py index c0efb6e8..6d1cb30a 100644 --- a/src/ansys/conceptev/core/auth.py +++ b/src/ansys/conceptev/core/auth.py @@ -100,7 +100,12 @@ def auth_flow(self, request): """Send the request, with a custom `Authentication` header.""" token = get_ansyId_token(self.app) request.headers["Authorization"] = token - yield request + response = yield request + if response.status_code == 401: + logger.info("Token expired or rejected (401). Refreshing token and retrying.") + token = get_ansyId_token(self.app, force=True) + request.headers["Authorization"] = token + yield request def get_token(client: httpx.Client) -> str: diff --git a/src/ansys/conceptev/core/progress.py b/src/ansys/conceptev/core/progress.py index c8e05af0..c4bb25c2 100644 --- a/src/ansys/conceptev/core/progress.py +++ b/src/ansys/conceptev/core/progress.py @@ -31,6 +31,7 @@ from msal import PublicClientApplication from websockets.asyncio.client import connect +from ansys.conceptev.core.auth import get_ansyId_token from ansys.conceptev.core.settings import settings if sys.version_info >= (3, 11): @@ -163,7 +164,7 @@ def monitor_job_progress( if __name__ == "__main__": """Monitor a single job progress.""" from ansys.conceptev.core.app import get_user_id - from ansys.conceptev.core.auth import create_msal_app, get_ansyId_token + from ansys.conceptev.core.auth import create_msal_app job_id = "ae3f3b4b-91d8-4cdd-8fa3-25eb202a561e" # Replace with your job ID msal_app = create_msal_app() diff --git a/tests/test_auth.py b/tests/test_auth.py index adc89eb5..d8647146 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -123,3 +123,71 @@ def test_auth_flow_adds_authorization_header(mocker, httpx_mock: HTTPXMock): client = httpx.Client(auth=auth_instance) response = client.get("http://example.com") assert response.request.headers["Authorization"] == "auth_class_token" + + +def test_auth_flow_retries_on_401_with_fresh_token(mocker, httpx_mock: HTTPXMock): + """When the server returns 401, auth_flow should force-refresh the token and retry.""" + mock_get_ansyId_token = mocker.patch( + "ansys.conceptev.core.auth.get_ansyId_token", + side_effect=["expired_token", "fresh_token"], + ) + auth_instance = auth.AnsysIDAuth() + httpx_mock.add_response(url="http://example.com", status_code=401) + httpx_mock.add_response(url="http://example.com", status_code=200) + + client = httpx.Client(auth=auth_instance) + response = client.get("http://example.com") + + assert response.status_code == 200 + assert mock_get_ansyId_token.call_count == 2 + # Second call must use force=True to bypass the MSAL cache + assert mock_get_ansyId_token.call_args_list[1] == mocker.call(auth_instance.app, force=True) + assert response.request.headers["Authorization"] == "fresh_token" + + +def test_auth_flow_does_not_retry_on_other_errors(mocker, httpx_mock: HTTPXMock): + """Non-401 errors should not trigger a token refresh retry.""" + mock_get_ansyId_token = mocker.patch( + "ansys.conceptev.core.auth.get_ansyId_token", return_value="auth_class_token" + ) + auth_instance = auth.AnsysIDAuth() + httpx_mock.add_response(url="http://example.com", status_code=403) + + client = httpx.Client(auth=auth_instance) + response = client.get("http://example.com") + + assert response.status_code == 403 + assert mock_get_ansyId_token.call_count == 1 + + +def test_auth_flow_token_expires_mid_sequence(mocker, httpx_mock: HTTPXMock): + """When a token expires mid-sequence, the failing request retries with a fresh token + and all subsequent requests in the sequence also use the fresh token.""" + # Requests 1 & 2 succeed with the original token. + # Request 3 gets a 401 (token just expired), then retries with a fresh token. + # Request 4 should use the now-cached fresh token. + token_sequence = [ + "original_token", # req 1 – initial fetch + "original_token", # req 2 – initial fetch + "original_token", # req 3 – initial fetch (will be rejected) + "fresh_token", # req 3 – force-refresh after 401 + "fresh_token", # req 4 – silent fetch (MSAL cache now holds fresh token) + ] + mock_get_ansyId_token = mocker.patch( + "ansys.conceptev.core.auth.get_ansyId_token", side_effect=token_sequence + ) + auth_instance = auth.AnsysIDAuth() + + httpx_mock.add_response(url="http://example.com", status_code=200) # req 1 + httpx_mock.add_response(url="http://example.com", status_code=200) # req 2 + httpx_mock.add_response(url="http://example.com", status_code=401) # req 3 – token expired + httpx_mock.add_response(url="http://example.com", status_code=200) # req 3 – retry + httpx_mock.add_response(url="http://example.com", status_code=200) # req 4 + + client = httpx.Client(auth=auth_instance) + responses = [client.get("http://example.com") for _ in range(4)] + + assert [r.status_code for r in responses] == [200, 200, 200, 200] + assert mock_get_ansyId_token.call_count == 5 + # The force-refresh must have been triggered on the 4th call (req 3 retry) + assert mock_get_ansyId_token.call_args_list[3] == mocker.call(auth_instance.app, force=True) diff --git a/tests/test_progress.py b/tests/test_progress.py index 941957a1..72799c46 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -150,3 +150,44 @@ def test_ssl_cert_custom(): ssl_context = generate_ssl_context() assert ssl_context is not None assert ssl_context.verify_mode == ssl.CERT_REQUIRED + + +@pytest.mark.asyncio +async def test_token_refreshed_on_websocket_reconnect(): + """When a long-running job causes a WebSocket disconnection, a fresh token should + be fetched and used when reconnecting — simulating a mid-job token expiry.""" + job_id = "test_job" + user_id = "test_user" + initial_token = "initial_token" + refreshed_token = "refreshed_token" + app = PublicClientApplication("123") + + progress_message = json.dumps({"jobId": job_id, "messagetype": "progress", "progress": 50}) + complete_message = json.dumps( + {"jobId": job_id, "messagetype": "status", "status": STATUS_COMPLETE} + ) + + # First connection: delivers a progress message then disconnects (simulates token expiry). + # Second connection: delivers the completion message. + connection_calls = [] + + def fake_connect_to_ocm(uid, token): + connection_calls.append(token) + if len(connection_calls) == 1: + return AsyncContextManager([progress_message]) + return AsyncContextManager([complete_message]) + + with patch( + "ansys.conceptev.core.progress.connect_to_ocm", side_effect=fake_connect_to_ocm + ), patch( + "ansys.conceptev.core.progress.get_ansyId_token", return_value=refreshed_token + ) as mock_refresh: + result = await monitor_job_messages(job_id, user_id, initial_token, app) + + assert result == STATUS_COMPLETE + # First connection used the original token passed in + assert connection_calls[0] == initial_token + # Second connection used the refreshed token + assert connection_calls[1] == refreshed_token + # get_ansyId_token was called once to refresh after the first WebSocket disconnected + mock_refresh.assert_called_once_with(app)