diff --git a/src/modelAccessors/openai_accessor.py b/src/modelAccessors/openai_accessor.py
index e72c79c..8e06a03 100644
--- a/src/modelAccessors/openai_accessor.py
+++ b/src/modelAccessors/openai_accessor.py
@@ -1,3 +1,4 @@
+import json
 from os import environ
 from typing import Any, Optional
 
@@ -47,12 +48,15 @@ def call_model(
             {"role": "user", "content": prompt},
         ]
 
+        # Ensure the schema is compatible with OpenAI's requirements
+        openai_schema = self._prepare_schema_for_openai(schema)
+
         kwargs = {
             "model": model,
             "messages": messages,
             "response_format": {
                 "type": "json_schema",
-                "json_schema": {"name": "response", "schema": schema, "strict": True},
+                "json_schema": {"name": "response", "schema": openai_schema, "strict": True},
             },
         }
 
@@ -65,17 +69,82 @@ def call_model(
         # Use parsed response when available
         parsed = getattr(message, "parsed", None)
         if parsed is not None:
-            return adapter.validate_python(parsed)
+            return adapter.validate_python(self._extract_response_from_openai_format(parsed, schema))
 
         # Fallback to JSON parsing if parsed is not available
         raw = message.content
         if not raw:
             raise ValueError("No content in response")
-        return adapter.validate_json(raw)
+
+        # Schemas sent unwrapped need no envelope handling: validate the raw text directly.
+        if not self._schema_needs_wrapping(schema):
+            return adapter.validate_json(raw)
+
+        # Strip the envelope, then re-serialize so validation still runs in JSON mode.
+        extracted_response = self._extract_response_from_openai_format(json.loads(raw), schema)
+        return adapter.validate_json(json.dumps(extracted_response))
 
     def supports_tools(self, model: str) -> bool:
         """Check if model supports native tools/function calling"""
         return model in self.tool_supported_models
+
+    @staticmethod
+    def _schema_needs_wrapping(schema: dict) -> bool:
+        """
+        Tell whether a schema must be wrapped in an object before it is sent to OpenAI.
+
+        True when the root is not already an object and is a oneOf/anyOf union.
+        """
+        return schema.get("type") != "object" and ("oneOf" in schema or "anyOf" in schema)
+
+    def _prepare_schema_for_openai(self, schema: dict) -> dict:
+        """
+        Prepare schema for OpenAI's structured output requirements.
+
+        OpenAI requires the root schema to have 'type': 'object', but Pydantic's
+        discriminated unions generate schemas with oneOf at the root level. Such a
+        schema is nested under a single required "response" property; any other
+        schema is returned unchanged. The input schema is never mutated.
+
+        Root-level "$defs"/"definitions" are hoisted onto the wrapper: Pydantic's
+        "$ref" values are JSON pointers from the document root ("#/$defs/Name"), so
+        they would dangle if the definitions moved down together with the union.
+
+        NOTE(review): OpenAI's strict mode documents anyOf support; confirm that a
+        nested oneOf (with its discriminator) is accepted as-is.
+        """
+        if not self._schema_needs_wrapping(schema):
+            return schema
+
+        definition_keys = ("$defs", "definitions")
+        wrapped = {
+            "type": "object",
+            "properties": {
+                "response": {k: v for k, v in schema.items() if k not in definition_keys},
+            },
+            "required": ["response"],
+            "additionalProperties": False,
+        }
+        # Keep definitions at the document root so existing "$ref" pointers still resolve.
+        wrapped.update({k: schema[k] for k in definition_keys if k in schema})
+        return wrapped
+
+    def _extract_response_from_openai_format(self, response_data: Any, original_schema: dict) -> Any:
+        """
+        Extract the actual response from OpenAI's format, handling wrapped schemas.
+
+        If the original schema was wrapped (oneOf/anyOf), unwrap the response.
+        Otherwise, return the response as-is. A payload without the "response"
+        envelope is also returned as-is, so the caller's validation reports it.
+        """
+        if (
+            self._schema_needs_wrapping(original_schema)
+            and isinstance(response_data, dict)
+            and "response" in response_data
+        ):
+            return response_data["response"]
+
+        return response_data
 
     def _convert_to_openai_tools(self, tools: list[Tool]) -> list[dict[str, Any]]:
         """Convert our Tool objects to OpenAI's tool format"""
diff --git a/tests/modelAccessors/test_openai_accessor.py b/tests/modelAccessors/test_openai_accessor.py
index c642a13..e7bfa1a 100644
--- a/tests/modelAccessors/test_openai_accessor.py
+++ b/tests/modelAccessors/test_openai_accessor.py
@@ -1,12 +1,14 @@
 """Test OpenAI accessor model validation and structured output support."""
 
+import json
 import os
+from typing import cast
 from unittest.mock import Mock, patch
 
 from pydantic import TypeAdapter
 
 from src.modelAccessors.openai_accessor import OpenAIAccessor
-from src.dataModel.model_response import ImplementedResponse
+from src.dataModel.model_response import ClarifierResponse, ImplementedResponse, ModelResponse
 
 
 def test_supported_model_validation():
@@ -118,4 +120,143 @@ def test_tool_support():
     assert accessor.supports_tools("gpt-5-nano")
 
     # Old models shouldn't be in tool support either
-    assert not accessor.supports_tools("gpt-4")
\ No newline at end of file
+    assert not accessor.supports_tools("gpt-4")
+
+
+def test_prepare_schema_for_openai():
+    """Test that discriminated union schemas are properly wrapped for OpenAI compatibility."""
+    # Test with discriminated union (should be wrapped)
+    adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
+    union_schema = adapter.json_schema()
+
+    # Create accessor without initializing the client
+    with patch('src.modelAccessors.openai_accessor.OpenAI'):
+        accessor = OpenAIAccessor()
+
+    # Test union schema wrapping
+    fixed_schema = accessor._prepare_schema_for_openai(union_schema)
+    assert fixed_schema["type"] == "object"
+    assert "properties" in fixed_schema
+    assert "response" in fixed_schema["properties"]
+    assert fixed_schema["required"] == ["response"]
+    assert fixed_schema["additionalProperties"] is False
+
+    # Definitions are hoisted to the wrapper root; everything else is nested unchanged
+    definition_keys = ("$defs", "definitions")
+    expected_inner = {k: v for k, v in union_schema.items() if k not in definition_keys}
+    assert fixed_schema["properties"]["response"] == expected_inner
+    for key in definition_keys:
+        if key in union_schema:
+            assert fixed_schema[key] == union_schema[key]
+
+    # Test object schema (should not be wrapped)
+    object_schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+        "required": ["test"]
+    }
+    unchanged_schema = accessor._prepare_schema_for_openai(object_schema)
+    assert unchanged_schema == object_schema
+
+
+def test_prepare_schema_hoists_definitions():
+    """Test that $defs stay at the schema root so $ref pointers keep resolving."""
+    with patch('src.modelAccessors.openai_accessor.OpenAI'):
+        accessor = OpenAIAccessor()
+
+    union_schema = {
+        "$defs": {"A": {"type": "object", "properties": {"kind": {"type": "string"}}}},
+        "oneOf": [{"$ref": "#/$defs/A"}],
+    }
+    fixed_schema = accessor._prepare_schema_for_openai(union_schema)
+
+    assert fixed_schema["$defs"] == union_schema["$defs"]
+    assert fixed_schema["properties"]["response"] == {"oneOf": [{"$ref": "#/$defs/A"}]}
+    # The caller's schema must not be mutated
+    assert "$defs" in union_schema
+
+
+def test_extract_response_from_openai_format():
+    """Test that responses are properly extracted from OpenAI format."""
+    with patch('src.modelAccessors.openai_accessor.OpenAI'):
+        accessor = OpenAIAccessor()
+
+    # Test with discriminated union schema (should be unwrapped)
+    adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
+    union_schema = adapter.json_schema()
+
+    wrapped_response = {"response": {"type": "implemented", "content": "test"}}
+    extracted = accessor._extract_response_from_openai_format(wrapped_response, union_schema)
+    assert extracted == {"type": "implemented", "content": "test"}
+
+    # Test with object schema (should not be unwrapped)
+    object_schema = {
+        "type": "object",
+        "properties": {"test": {"type": "string"}},
+        "required": ["test"]
+    }
+    object_response = {"test": "value"}
+    unchanged = accessor._extract_response_from_openai_format(object_response, object_schema)
+    assert unchanged == {"test": "value"}
+
+
+def test_response_unwrapping_integration():
+    """Test that wrapped responses are properly unwrapped in the full call flow."""
+    # Mock the OpenAI client and response
+    with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
+        accessor = OpenAIAccessor()
+        mock_client = Mock()
+        mock_openai.return_value = mock_client
+        accessor.client = mock_client
+
+        # Create mock response with wrapped data
+        mock_message = Mock()
+        mock_message.parsed = {"response": {"type": "implemented", "content": "test"}}
+        mock_response = Mock()
+        mock_response.choices = [Mock(message=mock_message)]
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Test the call
+        adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
+        schema = adapter.json_schema()
+
+        result = accessor.call_model(
+            prompt="test",
+            adapter=adapter,
+            schema=schema
+        )
+
+        # Verify the response was unwrapped correctly
+        assert hasattr(result, 'type')
+        assert result.type == "implemented"
+
+
+def test_json_fallback_unwrapping_integration():
+    """Test that JSON fallback also unwraps responses correctly."""
+    with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
+        accessor = OpenAIAccessor()
+        mock_client = Mock()
+        mock_openai.return_value = mock_client
+        accessor.client = mock_client
+
+        # Create mock response with JSON content (no parsed attribute)
+        mock_message = Mock()
+        mock_message.parsed = None
+        mock_message.content = json.dumps({"response": {"type": "implemented", "content": "test"}})
+        mock_response = Mock()
+        mock_response.choices = [Mock(message=mock_message)]
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Test the call
+        adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
+        schema = adapter.json_schema()
+
+        result = accessor.call_model(
+            prompt="test",
+            adapter=adapter,
+            schema=schema
+        )
+
+        # Verify the response was unwrapped correctly
+        assert hasattr(result, 'type')
+        assert result.type == "implemented"