Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changelog.d/325.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix/token timeout 2
2 changes: 1 addition & 1 deletion src/ansys/conceptev/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_http_client(
client.send = retry(
retry=retry_if_result(is_gateway_error),
wait=wait_random_exponential(multiplier=1, max=60),
stop=stop_after_delay(10),
stop=stop_after_delay(120),
)(client.send)
return client

Expand Down
7 changes: 6 additions & 1 deletion src/ansys/conceptev/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def auth_flow(self, request):
"""Send the request, with a custom `Authentication` header."""
token = get_ansyId_token(self.app)
request.headers["Authorization"] = token
yield request
response = yield request
if response.status_code == 401:
logger.info("Token expired or rejected (401). Refreshing token and retrying.")
token = get_ansyId_token(self.app, force=True)
request.headers["Authorization"] = token
yield request


def get_token(client: httpx.Client) -> str:
Expand Down
3 changes: 2 additions & 1 deletion src/ansys/conceptev/core/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from msal import PublicClientApplication
from websockets.asyncio.client import connect

from ansys.conceptev.core.auth import get_ansyId_token
from ansys.conceptev.core.settings import settings

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -163,7 +164,7 @@ def monitor_job_progress(
if __name__ == "__main__":
"""Monitor a single job progress."""
from ansys.conceptev.core.app import get_user_id
from ansys.conceptev.core.auth import create_msal_app, get_ansyId_token
from ansys.conceptev.core.auth import create_msal_app

job_id = "ae3f3b4b-91d8-4cdd-8fa3-25eb202a561e" # Replace with your job ID
msal_app = create_msal_app()
Expand Down
68 changes: 68 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,71 @@ def test_auth_flow_adds_authorization_header(mocker, httpx_mock: HTTPXMock):
client = httpx.Client(auth=auth_instance)
response = client.get("http://example.com")
assert response.request.headers["Authorization"] == "auth_class_token"


def test_auth_flow_retries_on_401_with_fresh_token(mocker, httpx_mock: HTTPXMock):
"""When the server returns 401, auth_flow should force-refresh the token and retry."""
mock_get_ansyId_token = mocker.patch(
"ansys.conceptev.core.auth.get_ansyId_token",
side_effect=["expired_token", "fresh_token"],
)
auth_instance = auth.AnsysIDAuth()
httpx_mock.add_response(url="http://example.com", status_code=401)
httpx_mock.add_response(url="http://example.com", status_code=200)

client = httpx.Client(auth=auth_instance)
response = client.get("http://example.com")

assert response.status_code == 200
assert mock_get_ansyId_token.call_count == 2
# Second call must use force=True to bypass the MSAL cache
assert mock_get_ansyId_token.call_args_list[1] == mocker.call(auth_instance.app, force=True)
assert response.request.headers["Authorization"] == "fresh_token"


def test_auth_flow_does_not_retry_on_other_errors(mocker, httpx_mock: HTTPXMock):
"""Non-401 errors should not trigger a token refresh retry."""
mock_get_ansyId_token = mocker.patch(
"ansys.conceptev.core.auth.get_ansyId_token", return_value="auth_class_token"
)
auth_instance = auth.AnsysIDAuth()
httpx_mock.add_response(url="http://example.com", status_code=403)

client = httpx.Client(auth=auth_instance)
response = client.get("http://example.com")

assert response.status_code == 403
assert mock_get_ansyId_token.call_count == 1


def test_auth_flow_token_expires_mid_sequence(mocker, httpx_mock: HTTPXMock):
"""When a token expires mid-sequence, the failing request retries with a fresh token
and all subsequent requests in the sequence also use the fresh token."""
# Requests 1 & 2 succeed with the original token.
# Request 3 gets a 401 (token just expired), then retries with a fresh token.
# Request 4 should use the now-cached fresh token.
token_sequence = [
"original_token", # req 1 – initial fetch
"original_token", # req 2 – initial fetch
"original_token", # req 3 – initial fetch (will be rejected)
"fresh_token", # req 3 – force-refresh after 401
"fresh_token", # req 4 – silent fetch (MSAL cache now holds fresh token)
]
mock_get_ansyId_token = mocker.patch(
"ansys.conceptev.core.auth.get_ansyId_token", side_effect=token_sequence
)
auth_instance = auth.AnsysIDAuth()

httpx_mock.add_response(url="http://example.com", status_code=200) # req 1
httpx_mock.add_response(url="http://example.com", status_code=200) # req 2
httpx_mock.add_response(url="http://example.com", status_code=401) # req 3 – token expired
httpx_mock.add_response(url="http://example.com", status_code=200) # req 3 – retry
httpx_mock.add_response(url="http://example.com", status_code=200) # req 4

client = httpx.Client(auth=auth_instance)
responses = [client.get("http://example.com") for _ in range(4)]

assert [r.status_code for r in responses] == [200, 200, 200, 200]
assert mock_get_ansyId_token.call_count == 5
# The force-refresh must have been triggered on the 4th call (req 3 retry)
assert mock_get_ansyId_token.call_args_list[3] == mocker.call(auth_instance.app, force=True)
41 changes: 41 additions & 0 deletions tests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,44 @@ def test_ssl_cert_custom():
ssl_context = generate_ssl_context()
assert ssl_context is not None
assert ssl_context.verify_mode == ssl.CERT_REQUIRED


@pytest.mark.asyncio
async def test_token_refreshed_on_websocket_reconnect():
"""When a long-running job causes a WebSocket disconnection, a fresh token should
be fetched and used when reconnecting — simulating a mid-job token expiry."""
job_id = "test_job"
user_id = "test_user"
initial_token = "initial_token"
refreshed_token = "refreshed_token"
app = PublicClientApplication("123")

progress_message = json.dumps({"jobId": job_id, "messagetype": "progress", "progress": 50})
complete_message = json.dumps(
{"jobId": job_id, "messagetype": "status", "status": STATUS_COMPLETE}
)

# First connection: delivers a progress message then disconnects (simulates token expiry).
# Second connection: delivers the completion message.
connection_calls = []

def fake_connect_to_ocm(uid, token):
connection_calls.append(token)
if len(connection_calls) == 1:
return AsyncContextManager([progress_message])
return AsyncContextManager([complete_message])

with patch(
"ansys.conceptev.core.progress.connect_to_ocm", side_effect=fake_connect_to_ocm
), patch(
"ansys.conceptev.core.progress.get_ansyId_token", return_value=refreshed_token
) as mock_refresh:
result = await monitor_job_messages(job_id, user_id, initial_token, app)

assert result == STATUS_COMPLETE
# First connection used the original token passed in
assert connection_calls[0] == initial_token
# Second connection used the refreshed token
assert connection_calls[1] == refreshed_token
# get_ansyId_token was called once to refresh after the first WebSocket disconnected
mock_refresh.assert_called_once_with(app)
Loading