Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions packages/client-python/src/rocketride/core/transport_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def set_uri(self, uri: str) -> None:
"""Update connection URI. Takes effect on the next connect()."""
self._uri = uri

def _on_message_task_done(self, task: asyncio.Task) -> None:
"""Handle completion of a message processing task."""
self._message_tasks.discard(task)

def _is_fastapi_websocket(self) -> bool:
"""
Check if current websocket is a FastAPI WebSocket instance.
Expand All @@ -147,6 +151,35 @@ def _is_fastapi_websocket(self) -> bool:
return False
return isinstance(self._websocket, WebSocket)

async def _cleanup_and_disconnect(self, reason: str, has_error: bool) -> None:
"""
Stop accepting new messages, cancel in-flight tasks, then notify disconnection.

Ensures _receive_data cannot interleave with _transport_disconnected callback
by setting _connected=False first and waiting for all tasks to complete.

Args:
reason: Reason for disconnection
has_error: Whether this was an error disconnection
"""
if not self._connected: # Already disconnected, skip
return

# Stop accepting new messages immediately
self._connected = False
Comment thread
asclearuc marked this conversation as resolved.

# Cancel and await all in-flight message tasks
if self._message_tasks:
tasks_to_cancel = [t for t in self._message_tasks if not t.done()]
for task in tasks_to_cancel:
task.cancel()
if tasks_to_cancel:
await asyncio.gather(*tasks_to_cancel, return_exceptions=True)
self._message_tasks.clear()

# Now safe to notify disconnection
await self._transport_disconnected(reason, has_error)

async def _receive_data(self, data: Union[str, bytes]) -> None:
"""
Process raw WebSocket data into structured messages.
Expand Down Expand Up @@ -243,30 +276,25 @@ async def _receive_loop(self) -> None:
# Process each message in its own task so others can be processed concurrently
task = asyncio.create_task(self._receive_data(data))
self._message_tasks.add(task)
task.add_done_callback(lambda t: self._message_tasks.discard(t))
task.add_done_callback(self._on_message_task_done)

except asyncio.CancelledError:
self._debug_message('WebSocket receive loop cancelled')
self._connected = False
await self._transport_disconnected('Cancelled by application', has_error=False)
await self._cleanup_and_disconnect('Cancelled by application', has_error=False)

except Exception as e:
# Handle different exception types
if fastapi and WebSocketDisconnect and isinstance(e, WebSocketDisconnect):
self._connected = False
await self._transport_disconnected('Connection closed', has_error=False)
await self._cleanup_and_disconnect('Connection closed', has_error=False)

elif websockets and hasattr(websockets.exceptions, 'ConnectionClosed') and isinstance(e, websockets.exceptions.ConnectionClosed):
self._connected = False
await self._transport_disconnected('Connection closed', has_error=False)
await self._cleanup_and_disconnect('Connection closed', has_error=False)

elif isinstance(e, (ConnectionResetError, ConnectionAbortedError)):
self._connected = False
await self._transport_disconnected(f'Connection error: {e}', has_error=True)
await self._cleanup_and_disconnect(f'Connection error: {e}', has_error=True)

else:
self._connected = False
await self._transport_disconnected(f'Unexpected error: {e}', has_error=True)
await self._cleanup_and_disconnect(f'Unexpected error: {e}', has_error=True)

finally:
self._connected = False
Expand Down Expand Up @@ -404,8 +432,6 @@ async def disconnect(self, reason: Optional[str] = None, has_error: bool = False
if not self._connected or not self._websocket:
return

callback_called = False

try:
self._debug_message('Gracefully disconnecting WebSocket')

Expand All @@ -424,35 +450,26 @@ async def disconnect(self, reason: Optional[str] = None, has_error: bool = False

self._debug_message('WebSocket disconnected successfully')

# Notify about disconnection (use caller-provided reason/has_error when given)
await self._transport_disconnected(reason or 'Disconnected by request', has_error)
callback_called = True
# Stop accepting new messages and cancel in-flight message tasks before notifying
await self._cleanup_and_disconnect(reason or 'Disconnected by request', has_error)

except asyncio.TimeoutError:
self._debug_message('Timeout during disconnect - forcing close')
if not callback_called:
await self._transport_disconnected('Disconnect timeout', has_error=True)
callback_called = True
await self._cleanup_and_disconnect('Disconnect timeout', has_error=True)

except (ConnectionResetError, ConnectionAbortedError):
self._debug_message('Connection closed by peer during disconnect')
if not callback_called:
await self._transport_disconnected('Connection closed by peer', has_error=False)
callback_called = True
await self._cleanup_and_disconnect('Connection closed by peer', has_error=False)

except Exception as e:
self._debug_message(f'Error during disconnect: {e}')
if not callback_called:
await self._transport_disconnected(f'Disconnect error: {e}', has_error=True)
callback_called = True
await self._cleanup_and_disconnect(f'Disconnect error: {e}', has_error=True)

finally:
# Always clean up resources
self._connected = False
# Always clean up remaining resources
self._websocket = None
self._receive_task = None
if hasattr(self, '_message_tasks'):
self._message_tasks.clear()
self._message_tasks.clear()

async def send(self, message: Dict[str, Any]) -> None:
"""
Expand Down
Loading