diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 54ee6df..486ad58 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -92,4 +92,4 @@ jobs: AMIGO_USER_ID: ${{ secrets.AMIGO_USER_ID }} AMIGO_BASE_URL: ${{ secrets.AMIGO_BASE_URL }} AMIGO_TEST_SERVICE_ID: ${{ secrets.AMIGO_TEST_SERVICE_ID }} - run: pytest -m integration -v -s + run: pytest -m integration -v -s --no-cov diff --git a/src/amigo_sdk/resources/conversation.py b/src/amigo_sdk/resources/conversation.py index 7f8e585..bed3713 100644 --- a/src/amigo_sdk/resources/conversation.py +++ b/src/amigo_sdk/resources/conversation.py @@ -10,6 +10,7 @@ from amigo_sdk.generated.model import ( ConversationCreateConversationRequest, ConversationCreateConversationResponse, + ConversationEvent, ConversationGenerateConversationStarterRequest, ConversationGenerateConversationStarterResponse, ConversationGetConversationMessagesResponse, @@ -182,10 +183,18 @@ async def _generator(): f"/v1/{self._organization_id}/conversation/{conversation_id}/interact", **request_kwargs, ): - # Each line is a JSON object representing a discriminated union event - yield ConversationInteractWithConversationResponse.model_validate_json( - line + # Each line is a JSON object representing a discriminated union event. + # The response wraps events in ConversationEvent RootModel; unwrap + # so callers get concrete event types directly. + parsed = ( + ConversationInteractWithConversationResponse.model_validate_json( + line + ) ) + event = parsed.root + if isinstance(event, ConversationEvent): + parsed.root = event.root + yield parsed return _generator() @@ -401,9 +410,15 @@ def _iter(): f"/v1/{self._organization_id}/conversation/{conversation_id}/interact", **request_kwargs, ): - yield ConversationInteractWithConversationResponse.model_validate_json( - line + parsed = ( + ConversationInteractWithConversationResponse.model_validate_json( + line + ) ) + event = parsed.root + if isinstance(event, ConversationEvent): + parsed.root = event.root + yield parsed return _iter() diff --git a/tests/integration/test_conversation_integration.py b/tests/integration/test_conversation_integration.py index 01965af..d85cf15 100644 --- a/tests/integration/test_conversation_integration.py +++ b/tests/integration/test_conversation_integration.py @@ -10,7 +10,6 @@ from amigo_sdk.generated.model import ( ConversationCreateConversationRequest, ConversationCreatedEvent, - ConversationEvent, CreateConversationParametersQuery, ErrorEvent, GetConversationMessagesParametersQuery, @@ -23,14 +22,6 @@ ) from amigo_sdk.sdk_client import AmigoClient, AsyncAmigoClient - -def _unwrap_event(e: object) -> object: - """Unwrap nested RootModel events (ConversationEvent wraps InteractionCompleteEvent etc.).""" - if isinstance(e, ConversationEvent): - return e.root - return e - - # Constants SERVICE_ID = os.getenv("AMIGO_TEST_SERVICE_ID", "66e0da39f5a09fb3cf18ea75") @@ -189,7 +180,7 @@ async def test_interact_with_conversation_text_streams(self): event_count = 0 async for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}") @@ -239,7 +230,7 @@ async def test_interact_with_conversation_external_event_streams(self): event_count = 0 async for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}") @@ -280,7 +271,7 @@ async def test_interact_with_conversation_voice_streams(self): event_count = 0 async for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}") @@ -455,7 +446,7 @@ def test_interact_with_conversation_text_streams(self): event_count = 0 for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}") @@ -505,7 +496,7 @@ def test_interact_with_conversation_external_event_streams(self): event_count = 0 for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}") @@ -546,7 +537,7 @@ def test_interact_with_conversation_voice_streams(self): event_count = 0 for evt in events: - e = _unwrap_event(evt.root) + e = evt.root event_count += 1 if isinstance(e, ErrorEvent): pytest.fail(f"error event: {e.model_dump_json()}")