Skip to content

Commit 8570421

Browse files
authored
Merge pull request #178 from AbdulmujibOladayo/main
Typing event broadcast to all participants, not persisted
2 parents 452a3cb + 54ac75e commit 8570421

4 files changed

Lines changed: 294 additions & 14 deletions

File tree

src/chat.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import asyncio
33
import json
44
from datetime import datetime
5-
from typing import Any, Dict, List, Optional, Set
5+
from typing import Any, Dict, List, Literal, Optional, Set, Union
66

77
from fastapi import WebSocket
88
from pydantic import BaseModel
@@ -27,6 +27,24 @@ class ChatMessage(BaseModel):
2727
metadata: Optional[Dict[str, Any]] = None
2828

2929

30+
class TypingEvent(BaseModel):
31+
"""Ephemeral event indicating a participant is typing."""
32+
33+
type: Literal["typing"] = "typing"
34+
sender_id: str
35+
conversation_id: str
36+
is_typing: bool
37+
38+
39+
class ReadReceiptEvent(BaseModel):
40+
"""Ephemeral event indicating messages have been read."""
41+
42+
type: Literal["read_receipt"] = "read_receipt"
43+
sender_id: str
44+
conversation_id: str
45+
last_read_message_id: str
46+
47+
3048
class EscalationEvent(BaseModel):
3149
"""Represents an escalation event."""
3250

@@ -119,6 +137,33 @@ async def send_message(self, message: ChatMessage) -> bool:
119137

120138
return True
121139

140+
async def broadcast_event(
141+
self, event: Union[TypingEvent, ReadReceiptEvent]
142+
) -> bool:
143+
"""Broadcast an ephemeral event to all participants without persisting it."""
144+
conversation_id = event.conversation_id
145+
if conversation_id not in self.active_connections:
146+
return True
147+
148+
disconnected: List[WebSocket] = []
149+
for websocket in self.active_connections[conversation_id]:
150+
try:
151+
await websocket.send_text(event.model_dump_json())
152+
except Exception as exc:
153+
log_warning(
154+
"Failed to send event to websocket",
155+
{"conversation_id": conversation_id, "error": str(exc)},
156+
)
157+
disconnected.append(websocket)
158+
159+
if disconnected:
160+
async with self._lock:
161+
for ws in disconnected:
162+
if ws in self.active_connections.get(conversation_id, []):
163+
self.active_connections[conversation_id].remove(ws)
164+
165+
return True
166+
122167
def get_message_history(
123168
self, conversation_id: str, limit: int = 50
124169
) -> List[ChatMessage]:

src/main.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from src.auth.dependencies import require_admin_key, require_service_key
1818

1919
from src.analytics.service import analytics_service
20-
from src.chat import ChatMessage, EscalationEvent, chat_manager
20+
from src.chat import ChatMessage, EscalationEvent, ReadReceiptEvent, TypingEvent, chat_manager
2121
from src.config import get_settings
2222
from src.core.ratelimit import limiter
2323
from src.etl import diff_etl_output, extract_events_and_sales, run_etl_once, transform_summary
@@ -79,6 +79,8 @@
7979
ChatMessageHistoryResponse,
8080
ChatMessageSendRequest,
8181
ChatMessageSendResponse,
82+
ChatTypingRequest,
83+
ChatTypingResponse,
8284
ChatUserConversationsResponse,
8385
DailyReportRequest,
8486
DailyReportResponse,
@@ -747,16 +749,32 @@ async def websocket_chat(
747749
data = await websocket.receive_text()
748750
try:
749751
message_data: Dict[str, Any] = json.loads(data)
750-
message = ChatMessage(
751-
id=str(uuid.uuid4()),
752-
sender_id=user_id,
753-
sender_type=message_data.get("sender_type", "user"),
754-
content=message_data["content"],
755-
timestamp=datetime.utcnow(),
756-
conversation_id=conversation_id,
757-
metadata=message_data.get("metadata", {}),
758-
)
759-
await chat_manager.send_message(message)
752+
msg_type = message_data.get("type")
753+
if msg_type == "typing":
754+
event = TypingEvent(
755+
sender_id=user_id,
756+
conversation_id=conversation_id,
757+
is_typing=message_data.get("is_typing", False),
758+
)
759+
await chat_manager.broadcast_event(event)
760+
elif msg_type == "read_receipt":
761+
event = ReadReceiptEvent(
762+
sender_id=user_id,
763+
conversation_id=conversation_id,
764+
last_read_message_id=message_data["last_read_message_id"],
765+
)
766+
await chat_manager.broadcast_event(event)
767+
else:
768+
message = ChatMessage(
769+
id=str(uuid.uuid4()),
770+
sender_id=user_id,
771+
sender_type=message_data.get("sender_type", "user"),
772+
content=message_data["content"],
773+
timestamp=datetime.utcnow(),
774+
conversation_id=conversation_id,
775+
metadata=message_data.get("metadata", {}),
776+
)
777+
await chat_manager.send_message(message)
760778
except json.JSONDecodeError:
761779
logger.warning("Invalid JSON received from client")
762780
except KeyError as exc:
@@ -801,6 +819,24 @@ async def send_message(
801819
raise HTTPException(status_code=500, detail="Failed to send message")
802820

803821

822+
@app.post("/chat/{conversation_id}/typing", response_model=ChatTypingResponse)
823+
async def send_typing_indicator(
824+
conversation_id: str, body: ChatTypingRequest
825+
) -> ChatTypingResponse:
826+
"""Broadcast a typing indicator to all participants in a conversation (not persisted)."""
827+
try:
828+
event = TypingEvent(
829+
sender_id=body.sender_id,
830+
conversation_id=conversation_id,
831+
is_typing=body.is_typing,
832+
)
833+
await chat_manager.broadcast_event(event)
834+
return ChatTypingResponse(status="success")
835+
except Exception as exc:
836+
logger.error("Error sending typing indicator: %s", exc)
837+
raise HTTPException(status_code=500, detail="Failed to send typing indicator")
838+
839+
804840
@app.get("/chat/{conversation_id}/history", response_model=ChatMessageHistoryResponse)
805841
async def get_message_history(
806842
conversation_id: str,

src/types_custom.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,17 @@ class ChatUserConversationsResponse(BaseModel):
216216
count: int
217217

218218

219+
class ChatTypingRequest(BaseModel):
220+
model_config = ConfigDict(extra="forbid")
221+
sender_id: str = Field(..., min_length=1)
222+
is_typing: bool
223+
224+
225+
class ChatTypingResponse(BaseModel):
226+
model_config = ConfigDict(extra="forbid")
227+
status: Literal["success"]
228+
229+
219230
class AnalyticsStatsQuery(BaseModel):
220231
model_config = ConfigDict(extra="forbid")
221232
event_id: Optional[str] = None

tests/test_chat.py

Lines changed: 190 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import AsyncMock, Mock, patch
77
from fastapi.testclient import TestClient
88
from src.main import app
9-
from src.chat import ChatManager, ChatMessage, EscalationEvent
9+
from src.chat import ChatManager, ChatMessage, EscalationEvent, ReadReceiptEvent, TypingEvent
1010

1111

1212
@pytest.fixture
@@ -344,4 +344,192 @@ async def test_cleanup_on_websocket_disconnect(chat_manager):
344344
await chat_manager.disconnect(mock_websocket, conversation_id, user_id)
345345

346346
# Check cleanup
347-
assert len(chat_manager.active_connections.get(conversation_id, [])) == 0
347+
assert len(chat_manager.active_connections.get(conversation_id, [])) == 0
348+
349+
350+
# ---------------------------------------------------------------------------
351+
# Typing indicators and read receipts
352+
# ---------------------------------------------------------------------------
353+
354+
@pytest.mark.asyncio
355+
async def test_typing_event_broadcast_not_persisted(chat_manager):
356+
"""Typing events are broadcast to all participants but not stored in message history."""
357+
conversation_id = "typing_test_conv"
358+
359+
websockets = [AsyncMock() for _ in range(2)]
360+
for i, ws in enumerate(websockets):
361+
await chat_manager.connect(ws, conversation_id, f"user_{i}")
362+
363+
event = TypingEvent(
364+
sender_id="user_0",
365+
conversation_id=conversation_id,
366+
is_typing=True,
367+
)
368+
await chat_manager.broadcast_event(event)
369+
370+
# All participants should receive the event
371+
for ws in websockets:
372+
ws.send_text.assert_called()
373+
374+
# Nothing should be stored in message history
375+
assert chat_manager.message_history.get(conversation_id, []) == []
376+
377+
378+
@pytest.mark.asyncio
379+
async def test_typing_event_stop_broadcast_not_persisted(chat_manager):
380+
"""is_typing=False is broadcast and not persisted."""
381+
conversation_id = "typing_stop_conv"
382+
383+
ws = AsyncMock()
384+
await chat_manager.connect(ws, conversation_id, "user_1")
385+
386+
event = TypingEvent(
387+
sender_id="user_1",
388+
conversation_id=conversation_id,
389+
is_typing=False,
390+
)
391+
await chat_manager.broadcast_event(event)
392+
393+
ws.send_text.assert_called_once()
394+
assert chat_manager.message_history.get(conversation_id, []) == []
395+
396+
397+
@pytest.mark.asyncio
398+
async def test_read_receipt_broadcast_not_persisted(chat_manager):
399+
"""Read receipt events are broadcast to all participants but not stored in message history."""
400+
conversation_id = "receipt_test_conv"
401+
402+
websockets = [AsyncMock() for _ in range(3)]
403+
for i, ws in enumerate(websockets):
404+
await chat_manager.connect(ws, conversation_id, f"user_{i}")
405+
406+
event = ReadReceiptEvent(
407+
sender_id="user_0",
408+
conversation_id=conversation_id,
409+
last_read_message_id="msg_42",
410+
)
411+
await chat_manager.broadcast_event(event)
412+
413+
for ws in websockets:
414+
ws.send_text.assert_called()
415+
416+
assert chat_manager.message_history.get(conversation_id, []) == []
417+
418+
419+
@pytest.mark.asyncio
420+
async def test_broadcast_event_no_active_connections(chat_manager):
421+
"""broadcast_event returns True when there are no active connections."""
422+
event = TypingEvent(
423+
sender_id="user_x",
424+
conversation_id="empty_conv",
425+
is_typing=True,
426+
)
427+
result = await chat_manager.broadcast_event(event)
428+
assert result is True
429+
430+
431+
@pytest.mark.asyncio
432+
async def test_typing_event_payload_shape(chat_manager):
433+
"""The JSON sent for a typing event contains the correct fields."""
434+
conversation_id = "shape_test_conv"
435+
ws = AsyncMock()
436+
await chat_manager.connect(ws, conversation_id, "user_1")
437+
438+
event = TypingEvent(
439+
sender_id="user_1",
440+
conversation_id=conversation_id,
441+
is_typing=True,
442+
)
443+
await chat_manager.broadcast_event(event)
444+
445+
ws.send_text.assert_called_once()
446+
payload = json.loads(ws.send_text.call_args[0][0])
447+
assert payload["type"] == "typing"
448+
assert payload["sender_id"] == "user_1"
449+
assert payload["conversation_id"] == conversation_id
450+
assert payload["is_typing"] is True
451+
452+
453+
@pytest.mark.asyncio
454+
async def test_read_receipt_payload_shape(chat_manager):
455+
"""The JSON sent for a read receipt event contains the correct fields."""
456+
conversation_id = "receipt_shape_conv"
457+
ws = AsyncMock()
458+
await chat_manager.connect(ws, conversation_id, "user_2")
459+
460+
event = ReadReceiptEvent(
461+
sender_id="user_2",
462+
conversation_id=conversation_id,
463+
last_read_message_id="msg_99",
464+
)
465+
await chat_manager.broadcast_event(event)
466+
467+
ws.send_text.assert_called_once()
468+
payload = json.loads(ws.send_text.call_args[0][0])
469+
assert payload["type"] == "read_receipt"
470+
assert payload["sender_id"] == "user_2"
471+
assert payload["last_read_message_id"] == "msg_99"
472+
473+
474+
@pytest.mark.asyncio
475+
async def test_regular_message_still_persisted_after_typing(chat_manager):
476+
"""Regular chat messages are still persisted even after typing events are broadcast."""
477+
conversation_id = "mixed_conv"
478+
ws = AsyncMock()
479+
await chat_manager.connect(ws, conversation_id, "user_1")
480+
481+
# Broadcast a typing event (should not persist)
482+
typing_event = TypingEvent(
483+
sender_id="user_1",
484+
conversation_id=conversation_id,
485+
is_typing=True,
486+
)
487+
await chat_manager.broadcast_event(typing_event)
488+
489+
# Send a real message (should persist)
490+
message = ChatMessage(
491+
id="msg_real",
492+
sender_id="user_1",
493+
sender_type="user",
494+
content="Here is my message",
495+
timestamp=datetime.utcnow(),
496+
conversation_id=conversation_id,
497+
)
498+
await chat_manager.send_message(message)
499+
500+
history = chat_manager.get_message_history(conversation_id)
501+
assert len(history) == 1
502+
assert history[0].id == "msg_real"
503+
504+
505+
# ---------------------------------------------------------------------------
506+
# HTTP typing endpoint
507+
# ---------------------------------------------------------------------------
508+
509+
def test_typing_endpoint_broadcasts_and_returns_success(client):
510+
"""POST /chat/{conversation_id}/typing returns 200 with status=success."""
511+
response = client.post(
512+
"/chat/conv_typing_http/typing",
513+
json={"sender_id": "user_1", "is_typing": True},
514+
)
515+
assert response.status_code == 200
516+
assert response.json()["status"] == "success"
517+
518+
519+
def test_typing_endpoint_is_typing_false(client):
520+
"""POST /chat/{conversation_id}/typing with is_typing=False returns success."""
521+
response = client.post(
522+
"/chat/conv_typing_stop/typing",
523+
json={"sender_id": "user_2", "is_typing": False},
524+
)
525+
assert response.status_code == 200
526+
assert response.json()["status"] == "success"
527+
528+
529+
def test_typing_endpoint_missing_sender_id(client):
530+
"""POST /chat/{conversation_id}/typing with missing sender_id returns 422."""
531+
response = client.post(
532+
"/chat/conv_typing_bad/typing",
533+
json={"is_typing": True},
534+
)
535+
assert response.status_code == 422

0 commit comments

Comments
 (0)