Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions src/modelAccessors/openai_accessor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from os import environ
from typing import Any, Optional

Expand Down Expand Up @@ -47,12 +48,15 @@ def call_model(
{"role": "user", "content": prompt},
]

# Ensure the schema is compatible with OpenAI's requirements
openai_schema = self._prepare_schema_for_openai(schema)

kwargs = {
"model": model,
"messages": messages,
"response_format": {
"type": "json_schema",
"json_schema": {"name": "response", "schema": schema, "strict": True},
"json_schema": {"name": "response", "schema": openai_schema, "strict": True},
},
}

Expand All @@ -65,17 +69,64 @@ def call_model(
# Use parsed response when available
parsed = getattr(message, "parsed", None)
if parsed is not None:
return adapter.validate_python(parsed)
return adapter.validate_python(self._extract_response_from_openai_format(parsed, schema))

# Fallback to JSON parsing if parsed is not available
raw = message.content
if not raw:
raise ValueError("No content in response")
return adapter.validate_json(raw)

# Parse and extract response
parsed_json = json.loads(raw)
extracted_response = self._extract_response_from_openai_format(parsed_json, schema)
return adapter.validate_json(json.dumps(extracted_response))

def supports_tools(self, model: str) -> bool:
"""Check if model supports native tools/function calling"""
return model in self.tool_supported_models

def _prepare_schema_for_openai(self, schema: dict) -> dict:
"""
Prepare schema for OpenAI's structured output requirements.

OpenAI requires the root schema to have 'type': 'object', but Pydantic's
discriminated unions generate schemas with oneOf at the root level.
"""
# Check if the schema already has a root type of "object"
if schema.get("type") == "object":
return schema

# If it's a oneOf/anyOf schema (discriminated union), wrap it in an object
if "oneOf" in schema or "anyOf" in schema:
return {
"type": "object",
"properties": {
"response": schema
},
"required": ["response"],
"additionalProperties": False
}

return schema

def _extract_response_from_openai_format(self, response_data: Any, original_schema: dict) -> Any:
"""
Extract the actual response from OpenAI's format, handling wrapped schemas.

If the original schema was wrapped (oneOf/anyOf), unwrap the response.
Otherwise, return the response as-is.
"""
# Determine if the original schema was wrapped
schema_was_wrapped = (
original_schema.get("type") != "object" and
("oneOf" in original_schema or "anyOf" in original_schema)
)

# If schema was wrapped and response has the wrapped structure, unwrap it
if schema_was_wrapped and isinstance(response_data, dict) and "response" in response_data:
return response_data["response"]

return response_data

def _convert_to_openai_tools(self, tools: list[Tool]) -> list[dict[str, Any]]:
"""Convert our Tool objects to OpenAI's tool format"""
Expand Down
121 changes: 119 additions & 2 deletions tests/modelAccessors/test_openai_accessor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Test OpenAI accessor model validation and structured output support."""

import json
import os
from typing import cast
from unittest.mock import Mock, patch

from pydantic import TypeAdapter

from src.modelAccessors.openai_accessor import OpenAIAccessor
from src.dataModel.model_response import ImplementedResponse
from src.dataModel.model_response import ClarifierResponse, ImplementedResponse, ModelResponse


def test_supported_model_validation():
Expand Down Expand Up @@ -118,4 +120,119 @@ def test_tool_support():
assert accessor.supports_tools("gpt-5-nano")

# Old models shouldn't be in tool support either
assert not accessor.supports_tools("gpt-4")
assert not accessor.supports_tools("gpt-4")


def test_prepare_schema_for_openai():
"""Test that discriminated union schemas are properly wrapped for OpenAI compatibility."""
# Test with discriminated union (should be wrapped)
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
union_schema = adapter.json_schema()

# Create accessor without initializing the client
with patch('src.modelAccessors.openai_accessor.OpenAI'):
accessor = OpenAIAccessor()

# Test union schema wrapping
fixed_schema = accessor._prepare_schema_for_openai(union_schema)
assert fixed_schema["type"] == "object"
assert "properties" in fixed_schema
assert "response" in fixed_schema["properties"]
assert fixed_schema["required"] == ["response"]
assert fixed_schema["additionalProperties"] is False
assert fixed_schema["properties"]["response"] == union_schema

# Test object schema (should not be wrapped)
object_schema = {
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"]
}
unchanged_schema = accessor._prepare_schema_for_openai(object_schema)
assert unchanged_schema == object_schema


def test_extract_response_from_openai_format():
"""Test that responses are properly extracted from OpenAI format."""
with patch('src.modelAccessors.openai_accessor.OpenAI'):
accessor = OpenAIAccessor()

# Test with discriminated union schema (should be unwrapped)
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
union_schema = adapter.json_schema()

wrapped_response = {"response": {"type": "implemented", "content": "test"}}
extracted = accessor._extract_response_from_openai_format(wrapped_response, union_schema)
assert extracted == {"type": "implemented", "content": "test"}

# Test with object schema (should not be unwrapped)
object_schema = {
"type": "object",
"properties": {"test": {"type": "string"}},
"required": ["test"]
}
object_response = {"test": "value"}
unchanged = accessor._extract_response_from_openai_format(object_response, object_schema)
assert unchanged == {"test": "value"}


def test_response_unwrapping_integration():
"""Test that wrapped responses are properly unwrapped in the full call flow."""
# Mock the OpenAI client and response
with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
accessor = OpenAIAccessor()
mock_client = Mock()
mock_openai.return_value = mock_client
accessor.client = mock_client

# Create mock response with wrapped data
mock_message = Mock()
mock_message.parsed = {"response": {"type": "implemented", "content": "test"}}
mock_response = Mock()
mock_response.choices = [Mock(message=mock_message)]
mock_client.chat.completions.create.return_value = mock_response

# Test the call
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
schema = adapter.json_schema()

result = accessor.call_model(
prompt="test",
adapter=adapter,
schema=schema
)

# Verify the response was unwrapped correctly
assert hasattr(result, 'type')
assert result.type == "implemented"


def test_json_fallback_unwrapping_integration():
"""Test that JSON fallback also unwraps responses correctly."""
with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
accessor = OpenAIAccessor()
mock_client = Mock()
mock_openai.return_value = mock_client
accessor.client = mock_client

# Create mock response with JSON content (no parsed attribute)
mock_message = Mock()
mock_message.parsed = None
mock_message.content = json.dumps({"response": {"type": "implemented", "content": "test"}})
mock_response = Mock()
mock_response.choices = [Mock(message=mock_message)]
mock_client.chat.completions.create.return_value = mock_response

# Test the call
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
schema = adapter.json_schema()

result = accessor.call_model(
prompt="test",
adapter=adapter,
schema=schema
)

# Verify the response was unwrapped correctly
assert hasattr(result, 'type')
assert result.type == "implemented"