Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 142 additions & 25 deletions src/modelAccessors/openai_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,43 +89,160 @@ def _prepare_schema_for_openai(self, schema: dict) -> dict:
"""
Prepare schema for OpenAI's structured output requirements.

OpenAI requires the root schema to have 'type': 'object', but Pydantic's
discriminated unions generate schemas with oneOf at the root level.
OpenAI requires the root schema to have 'type': 'object' and does not support
oneOf/anyOf anywhere in the schema. Pydantic's discriminated unions generate
schemas with oneOf at the root level, and nullable fields use anyOf, so we
need to flatten and clean them.
"""
# Check if the schema already has a root type of "object"
if schema.get("type") == "object":
# Check if the schema already has a root type of "object" and no oneOf/anyOf
if schema.get("type") == "object" and not self._contains_oneof_anyof(schema):
return schema

# If it's a oneOf/anyOf schema (discriminated union), wrap it in an object
# If it's a oneOf/anyOf schema (discriminated union), flatten it
if "oneOf" in schema or "anyOf" in schema:
return {
"type": "object",
"properties": {
"response": schema
},
"required": ["response"],
"additionalProperties": False
flattened = self._flatten_discriminated_union(schema)
else:
flattened = schema

# Clean any remaining oneOf/anyOf structures (like nullable fields)
return self._clean_oneof_anyof_recursive(flattened)

def _contains_oneof_anyof(self, obj) -> bool:
"""Recursively check if an object contains oneOf or anyOf."""
if isinstance(obj, dict):
if "oneOf" in obj or "anyOf" in obj:
return True
for value in obj.values():
if isinstance(value, (dict, list)) and self._contains_oneof_anyof(value):
return True
elif isinstance(obj, list):
for item in obj:
if isinstance(item, (dict, list)) and self._contains_oneof_anyof(item):
return True
return False

def _clean_oneof_anyof_recursive(self, obj):
"""Recursively clean oneOf/anyOf structures from a schema."""
if isinstance(obj, dict):
# Handle nullable fields (anyOf with null)
if "anyOf" in obj:
any_of = obj["anyOf"]
# Check if this is a nullable field pattern: [{"type": "string"}, {"type": "null"}]
if len(any_of) == 2:
types = []
for item in any_of:
if isinstance(item, dict) and "type" in item:
types.append(item["type"])
if "null" in types:
# Convert to OpenAI-compatible nullable field
non_null_type = [t for t in types if t != "null"][0]
other_schema = next(item for item in any_of if item.get("type") == non_null_type)
# Create a new dict without anyOf
cleaned = {k: v for k, v in obj.items() if k != "anyOf"}
cleaned.update(other_schema)
# Make it nullable by not including it in required fields
return self._clean_oneof_anyof_recursive(cleaned)

# Handle oneOf (shouldn't happen after flattening, but just in case)
if "oneOf" in obj:
# This is a complex case - for now, take the first option
# In practice, this shouldn't happen after proper flattening
one_of = obj["oneOf"]
if one_of:
cleaned = {k: v for k, v in obj.items() if k != "oneOf"}
cleaned.update(one_of[0])
return self._clean_oneof_anyof_recursive(cleaned)

# Recursively clean all nested objects
return {k: self._clean_oneof_anyof_recursive(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._clean_oneof_anyof_recursive(item) for item in obj]
else:
return obj

def _flatten_discriminated_union(self, schema: dict) -> dict:
"""
Flatten a discriminated union schema to be compatible with OpenAI's structured outputs.

This converts a oneOf schema into a single object schema with all possible properties
from all union members, making conditional fields optional.
"""
# Get the oneOf alternatives
one_of_schemas = schema.get("oneOf", schema.get("anyOf", []))
if not one_of_schemas:
return schema

# Get definitions from the original schema
definitions = schema.get("$defs", {})

# Collect all properties and required fields from all alternatives
all_properties = {}
required_fields = set()

# The discriminator field (usually 'type') is required
discriminator_info = schema.get("discriminator", {})
discriminator_field = discriminator_info.get("propertyName")

if discriminator_field:
required_fields.add(discriminator_field)
# Create an enum for all possible discriminator values
possible_values = []
for alt_schema in one_of_schemas:
# Resolve $ref if present
resolved_schema = self._resolve_schema_ref(alt_schema, definitions)
properties = resolved_schema.get("properties", {})
if discriminator_field in properties:
disc_prop = properties[discriminator_field]
if "const" in disc_prop:
possible_values.append(disc_prop["const"])
elif "enum" in disc_prop:
possible_values.extend(disc_prop["enum"])

all_properties[discriminator_field] = {
"type": "string",
"enum": possible_values
}

# Merge properties from all alternatives
for alt_schema in one_of_schemas:
resolved_schema = self._resolve_schema_ref(alt_schema, definitions)
properties = resolved_schema.get("properties", {})

for prop_name, prop_schema in properties.items():
if prop_name not in all_properties:
all_properties[prop_name] = prop_schema
# Note: We don't make alternative-specific fields required in the flattened schema
# because they're only required conditionally based on the discriminator value

flattened = {
"type": "object",
"properties": all_properties,
"required": list(required_fields),
"additionalProperties": False
}

# Include definitions if they exist and are referenced
if definitions:
flattened["$defs"] = definitions

return flattened

def _resolve_schema_ref(self, schema: dict, definitions: dict) -> dict:
"""Resolve a $ref reference to get the actual schema definition."""
if "$ref" in schema:
ref_path = schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path.replace("#/$defs/", "")
return definitions.get(def_name, schema)
return schema

def _extract_response_from_openai_format(self, response_data: Any, original_schema: dict) -> Any:
"""
Extract the actual response from OpenAI's format, handling wrapped schemas.
Extract the actual response from OpenAI's format.

If the original schema was wrapped (oneOf/anyOf), unwrap the response.
Otherwise, return the response as-is.
Since we now flatten discriminated unions instead of wrapping them,
we can return the response data as-is.
"""
# Determine if the original schema was wrapped
schema_was_wrapped = (
original_schema.get("type") != "object" and
("oneOf" in original_schema or "anyOf" in original_schema)
)

# If schema was wrapped and response has the wrapped structure, unwrap it
if schema_was_wrapped and isinstance(response_data, dict) and "response" in response_data:
return response_data["response"]

return response_data

def _convert_to_openai_tools(self, tools: list[Tool]) -> list[dict[str, Any]]:
Expand Down
85 changes: 66 additions & 19 deletions tests/modelAccessors/test_openai_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,37 @@ def test_tool_support():


def test_prepare_schema_for_openai():
"""Test that discriminated union schemas are properly wrapped for OpenAI compatibility."""
# Test with discriminated union (should be wrapped)
"""Test that discriminated union schemas are properly flattened for OpenAI compatibility."""
# Test with discriminated union (should be flattened)
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
union_schema = adapter.json_schema()

# Create accessor without initializing the client
with patch('src.modelAccessors.openai_accessor.OpenAI'):
accessor = OpenAIAccessor()

# Test union schema wrapping
# Test union schema flattening
fixed_schema = accessor._prepare_schema_for_openai(union_schema)
assert fixed_schema["type"] == "object"
assert "properties" in fixed_schema
assert "response" in fixed_schema["properties"]
assert fixed_schema["required"] == ["response"]
assert "oneOf" not in fixed_schema
assert "anyOf" not in fixed_schema
assert fixed_schema["required"] == ["type"]
assert fixed_schema["additionalProperties"] is False
assert fixed_schema["properties"]["response"] == union_schema

# Test object schema (should not be wrapped)
# Should have properties from both FollowUpResponse and ImplementedResponse
properties = fixed_schema["properties"]
assert "type" in properties
assert "content" in properties # Common field
assert "artifacts" in properties # Common field
assert "follow_up_ask" in properties # FollowUpResponse field

# Type field should have correct enum values
type_prop = properties["type"]
assert type_prop["type"] == "string"
assert set(type_prop["enum"]) == {"follow_up_required", "implemented"}

# Test object schema (should not be changed)
object_schema = {
"type": "object",
"properties": {"test": {"type": "string"}},
Expand All @@ -157,15 +169,15 @@ def test_extract_response_from_openai_format():
with patch('src.modelAccessors.openai_accessor.OpenAI'):
accessor = OpenAIAccessor()

# Test with discriminated union schema (should be unwrapped)
# Test with discriminated union schema (should be returned as-is)
adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse))
union_schema = adapter.json_schema()

wrapped_response = {"response": {"type": "implemented", "content": "test"}}
extracted = accessor._extract_response_from_openai_format(wrapped_response, union_schema)
response = {"type": "implemented", "content": "test"}
extracted = accessor._extract_response_from_openai_format(response, union_schema)
assert extracted == {"type": "implemented", "content": "test"}

# Test with object schema (should not be unwrapped)
# Test with object schema (should be returned as-is)
object_schema = {
"type": "object",
"properties": {"test": {"type": "string"}},
Expand All @@ -177,17 +189,17 @@ def test_extract_response_from_openai_format():


def test_response_unwrapping_integration():
"""Test that wrapped responses are properly unwrapped in the full call flow."""
"""Test that flattened responses are properly handled in the full call flow."""
# Mock the OpenAI client and response
with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
accessor = OpenAIAccessor()
mock_client = Mock()
mock_openai.return_value = mock_client
accessor.client = mock_client

# Create mock response with wrapped data
# Create mock response with flattened data (no wrapping)
mock_message = Mock()
mock_message.parsed = {"response": {"type": "implemented", "content": "test"}}
mock_message.parsed = {"type": "implemented", "content": "test"}
mock_response = Mock()
mock_response.choices = [Mock(message=mock_message)]
mock_client.chat.completions.create.return_value = mock_response
Expand All @@ -202,13 +214,13 @@ def test_response_unwrapping_integration():
schema=schema
)

# Verify the response was unwrapped correctly
# Verify the response was handled correctly
assert hasattr(result, 'type')
assert result.type == "implemented"


def test_json_fallback_unwrapping_integration():
"""Test that JSON fallback also unwraps responses correctly."""
"""Test that JSON fallback also handles flattened responses correctly."""
with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai:
accessor = OpenAIAccessor()
mock_client = Mock()
Expand All @@ -218,7 +230,7 @@ def test_json_fallback_unwrapping_integration():
# Create mock response with JSON content (no parsed attribute)
mock_message = Mock()
mock_message.parsed = None
mock_message.content = json.dumps({"response": {"type": "implemented", "content": "test"}})
mock_message.content = json.dumps({"type": "implemented", "content": "test"})
mock_response = Mock()
mock_response.choices = [Mock(message=mock_message)]
mock_client.chat.completions.create.return_value = mock_response
Expand All @@ -233,6 +245,41 @@ def test_json_fallback_unwrapping_integration():
schema=schema
)

# Verify the response was unwrapped correctly
# Verify the response was handled correctly
assert hasattr(result, 'type')
assert result.type == "implemented"
assert result.type == "implemented"

def test_full_modelresponse_compatibility():
"""Test that the full ModelResponse schema is OpenAI compatible after flattening."""
# This tests the exact scenario from the GitHub issue
adapter = TypeAdapter(ModelResponse)
schema = adapter.json_schema()

with patch('src.modelAccessors.openai_accessor.OpenAI'):
accessor = OpenAIAccessor()

# The original schema has oneOf which OpenAI doesn't support
assert "oneOf" in schema

# After preparation, should be fully OpenAI compatible
openai_schema = accessor._prepare_schema_for_openai(schema)
assert openai_schema["type"] == "object"
assert "oneOf" not in str(openai_schema)
assert "anyOf" not in str(openai_schema)
assert openai_schema["required"] == ["type"]

# Should have all possible properties from all union members
properties = openai_schema["properties"]
expected_props = ["type", "content", "artifacts", "subtasks", "follow_up_ask", "error_message", "retryable"]
for prop in expected_props:
assert prop in properties, f"Missing property: {prop}"

# Type field should have all discriminator values
type_prop = properties["type"]
assert type_prop["type"] == "string"
expected_types = {"decomposed", "implemented", "follow_up_required", "failed"}
assert set(type_prop["enum"]) == expected_types

# Should preserve complex type references
assert "$defs" in openai_schema
assert "Task" in openai_schema["$defs"]