diff --git a/src/modelAccessors/openai_accessor.py b/src/modelAccessors/openai_accessor.py index 8e06a03..37e18ce 100644 --- a/src/modelAccessors/openai_accessor.py +++ b/src/modelAccessors/openai_accessor.py @@ -89,43 +89,160 @@ def _prepare_schema_for_openai(self, schema: dict) -> dict: """ Prepare schema for OpenAI's structured output requirements. - OpenAI requires the root schema to have 'type': 'object', but Pydantic's - discriminated unions generate schemas with oneOf at the root level. + OpenAI requires the root schema to have 'type': 'object' and does not support + oneOf/anyOf anywhere in the schema. Pydantic's discriminated unions generate + schemas with oneOf at the root level, and nullable fields use anyOf, so we + need to flatten and clean them. """ - # Check if the schema already has a root type of "object" - if schema.get("type") == "object": + # Check if the schema already has a root type of "object" and no oneOf/anyOf + if schema.get("type") == "object" and not self._contains_oneof_anyof(schema): return schema - # If it's a oneOf/anyOf schema (discriminated union), wrap it in an object + # If it's a oneOf/anyOf schema (discriminated union), flatten it if "oneOf" in schema or "anyOf" in schema: - return { - "type": "object", - "properties": { - "response": schema - }, - "required": ["response"], - "additionalProperties": False + flattened = self._flatten_discriminated_union(schema) + else: + flattened = schema + + # Clean any remaining oneOf/anyOf structures (like nullable fields) + return self._clean_oneof_anyof_recursive(flattened) + + def _contains_oneof_anyof(self, obj) -> bool: + """Recursively check if an object contains oneOf or anyOf.""" + if isinstance(obj, dict): + if "oneOf" in obj or "anyOf" in obj: + return True + for value in obj.values(): + if isinstance(value, (dict, list)) and self._contains_oneof_anyof(value): + return True + elif isinstance(obj, list): + for item in obj: + if isinstance(item, (dict, list)) and self._contains_oneof_anyof(item): + return True + return False + + def _clean_oneof_anyof_recursive(self, obj): + """Recursively clean oneOf/anyOf structures from a schema.""" + if isinstance(obj, dict): + # Handle nullable fields (anyOf with null) + if "anyOf" in obj: + any_of = obj["anyOf"] + # Check if this is a nullable field pattern: [{"type": "string"}, {"type": "null"}] + if len(any_of) == 2: + types = [] + for item in any_of: + if isinstance(item, dict) and "type" in item: + types.append(item["type"]) + if "null" in types: + # Convert to OpenAI-compatible nullable field + non_null_type = [t for t in types if t != "null"][0] + other_schema = next(item for item in any_of if item.get("type") == non_null_type) + # Create a new dict without anyOf + cleaned = {k: v for k, v in obj.items() if k != "anyOf"} + cleaned.update(other_schema) + # Make it nullable by not including it in required fields + return self._clean_oneof_anyof_recursive(cleaned) + + # Handle oneOf (shouldn't happen after flattening, but just in case) + if "oneOf" in obj: + # This is a complex case - for now, take the first option + # In practice, this shouldn't happen after proper flattening + one_of = obj["oneOf"] + if one_of: + cleaned = {k: v for k, v in obj.items() if k != "oneOf"} + cleaned.update(one_of[0]) + return self._clean_oneof_anyof_recursive(cleaned) + + # Recursively clean all nested objects + return {k: self._clean_oneof_anyof_recursive(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._clean_oneof_anyof_recursive(item) for item in obj] + else: + return obj + + def _flatten_discriminated_union(self, schema: dict) -> dict: + """ + Flatten a discriminated union schema to be compatible with OpenAI's structured outputs. + + This converts a oneOf schema into a single object schema with all possible properties + from all union members, making conditional fields optional. + """ + # Get the oneOf alternatives + one_of_schemas = schema.get("oneOf", schema.get("anyOf", [])) + if not one_of_schemas: + return schema + + # Get definitions from the original schema + definitions = schema.get("$defs", {}) + + # Collect all properties and required fields from all alternatives + all_properties = {} + required_fields = set() + + # The discriminator field (usually 'type') is required + discriminator_info = schema.get("discriminator", {}) + discriminator_field = discriminator_info.get("propertyName") + + if discriminator_field: + required_fields.add(discriminator_field) + # Create an enum for all possible discriminator values + possible_values = [] + for alt_schema in one_of_schemas: + # Resolve $ref if present + resolved_schema = self._resolve_schema_ref(alt_schema, definitions) + properties = resolved_schema.get("properties", {}) + if discriminator_field in properties: + disc_prop = properties[discriminator_field] + if "const" in disc_prop: + possible_values.append(disc_prop["const"]) + elif "enum" in disc_prop: + possible_values.extend(disc_prop["enum"]) + + all_properties[discriminator_field] = { + "type": "string", + "enum": possible_values } + # Merge properties from all alternatives + for alt_schema in one_of_schemas: + resolved_schema = self._resolve_schema_ref(alt_schema, definitions) + properties = resolved_schema.get("properties", {}) + + for prop_name, prop_schema in properties.items(): + if prop_name not in all_properties: + all_properties[prop_name] = prop_schema + # Note: We don't make alternative-specific fields required in the flattened schema + # because they're only required conditionally based on the discriminator value + + flattened = { + "type": "object", + "properties": all_properties, + "required": list(required_fields), + "additionalProperties": False + } + + # Include definitions if they exist and are referenced + if definitions: + flattened["$defs"] = definitions + + return flattened + + def _resolve_schema_ref(self, schema: dict, definitions: dict) -> dict: + """Resolve a $ref reference to get the actual schema definition.""" + if "$ref" in schema: + ref_path = schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.replace("#/$defs/", "") + return definitions.get(def_name, schema) return schema def _extract_response_from_openai_format(self, response_data: Any, original_schema: dict) -> Any: """ - Extract the actual response from OpenAI's format, handling wrapped schemas. + Extract the actual response from OpenAI's format. - If the original schema was wrapped (oneOf/anyOf), unwrap the response. - Otherwise, return the response as-is. + Since we now flatten discriminated unions instead of wrapping them, + we can return the response data as-is. """ - # Determine if the original schema was wrapped - schema_was_wrapped = ( - original_schema.get("type") != "object" and - ("oneOf" in original_schema or "anyOf" in original_schema) - ) - - # If schema was wrapped and response has the wrapped structure, unwrap it - if schema_was_wrapped and isinstance(response_data, dict) and "response" in response_data: - return response_data["response"] - return response_data def _convert_to_openai_tools(self, tools: list[Tool]) -> list[dict[str, Any]]: diff --git a/tests/modelAccessors/test_openai_accessor.py b/tests/modelAccessors/test_openai_accessor.py index e7bfa1a..6b24421 100644 --- a/tests/modelAccessors/test_openai_accessor.py +++ b/tests/modelAccessors/test_openai_accessor.py @@ -124,8 +124,8 @@ def test_tool_support(): def test_prepare_schema_for_openai(): - """Test that discriminated union schemas are properly wrapped for OpenAI compatibility.""" - # Test with discriminated union (should be wrapped) + """Test that discriminated union schemas are properly flattened for OpenAI compatibility.""" + # Test with discriminated union (should be flattened) adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse)) union_schema = adapter.json_schema() @@ -133,16 +133,28 @@ def test_prepare_schema_for_openai(): with patch('src.modelAccessors.openai_accessor.OpenAI'): accessor = OpenAIAccessor() - # Test union schema wrapping + # Test union schema flattening fixed_schema = accessor._prepare_schema_for_openai(union_schema) assert fixed_schema["type"] == "object" assert "properties" in fixed_schema - assert "response" in fixed_schema["properties"] - assert fixed_schema["required"] == ["response"] + assert "oneOf" not in fixed_schema + assert "anyOf" not in fixed_schema + assert fixed_schema["required"] == ["type"] assert fixed_schema["additionalProperties"] is False - assert fixed_schema["properties"]["response"] == union_schema - # Test object schema (should not be wrapped) + # Should have properties from both FollowUpResponse and ImplementedResponse + properties = fixed_schema["properties"] + assert "type" in properties + assert "content" in properties # Common field + assert "artifacts" in properties # Common field + assert "follow_up_ask" in properties # FollowUpResponse field + + # Type field should have correct enum values + type_prop = properties["type"] + assert type_prop["type"] == "string" + assert set(type_prop["enum"]) == {"follow_up_required", "implemented"} + + # Test object schema (should not be changed) object_schema = { "type": "object", "properties": {"test": {"type": "string"}}, @@ -157,15 +169,15 @@ def test_extract_response_from_openai_format(): with patch('src.modelAccessors.openai_accessor.OpenAI'): accessor = OpenAIAccessor() - # Test with discriminated union schema (should be unwrapped) + # Test with discriminated union schema (should be returned as-is) adapter = cast(TypeAdapter[ModelResponse], TypeAdapter(ClarifierResponse)) union_schema = adapter.json_schema() - wrapped_response = {"response": {"type": "implemented", "content": "test"}} - extracted = accessor._extract_response_from_openai_format(wrapped_response, union_schema) + response = {"type": "implemented", "content": "test"} + extracted = accessor._extract_response_from_openai_format(response, union_schema) assert extracted == {"type": "implemented", "content": "test"} - # Test with object schema (should not be unwrapped) + # Test with object schema (should be returned as-is) object_schema = { "type": "object", "properties": {"test": {"type": "string"}}, @@ -177,7 +189,7 @@ def test_extract_response_from_openai_format(): def test_response_unwrapping_integration(): - """Test that wrapped responses are properly unwrapped in the full call flow.""" + """Test that flattened responses are properly handled in the full call flow.""" # Mock the OpenAI client and response with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai: accessor = OpenAIAccessor() @@ -185,9 +197,9 @@ def test_response_unwrapping_integration(): mock_openai.return_value = mock_client accessor.client = mock_client - # Create mock response with wrapped data + # Create mock response with flattened data (no wrapping) mock_message = Mock() - mock_message.parsed = {"response": {"type": "implemented", "content": "test"}} + mock_message.parsed = {"type": "implemented", "content": "test"} mock_response = Mock() mock_response.choices = [Mock(message=mock_message)] mock_client.chat.completions.create.return_value = mock_response @@ -202,13 +214,13 @@ def test_response_unwrapping_integration(): schema=schema ) - # Verify the response was unwrapped correctly + # Verify the response was handled correctly assert hasattr(result, 'type') assert result.type == "implemented" def test_json_fallback_unwrapping_integration(): - """Test that JSON fallback also unwraps responses correctly.""" + """Test that JSON fallback also handles flattened responses correctly.""" with patch('src.modelAccessors.openai_accessor.OpenAI') as mock_openai: accessor = OpenAIAccessor() mock_client = Mock() @@ -218,7 +230,7 @@ def test_json_fallback_unwrapping_integration(): # Create mock response with JSON content (no parsed attribute) mock_message = Mock() mock_message.parsed = None - mock_message.content = json.dumps({"response": {"type": "implemented", "content": "test"}}) + mock_message.content = json.dumps({"type": "implemented", "content": "test"}) mock_response = Mock() mock_response.choices = [Mock(message=mock_message)] mock_client.chat.completions.create.return_value = mock_response @@ -233,6 +245,41 @@ def test_json_fallback_unwrapping_integration(): schema=schema ) - # Verify the response was unwrapped correctly + # Verify the response was handled correctly assert hasattr(result, 'type') - assert result.type == "implemented" \ No newline at end of file + assert result.type == "implemented" + +def test_full_modelresponse_compatibility(): + """Test that the full ModelResponse schema is OpenAI compatible after flattening.""" + # This tests the exact scenario from the GitHub issue + adapter = TypeAdapter(ModelResponse) + schema = adapter.json_schema() + + with patch('src.modelAccessors.openai_accessor.OpenAI'): + accessor = OpenAIAccessor() + + # The original schema has oneOf which OpenAI doesn't support + assert "oneOf" in schema + + # After preparation, should be fully OpenAI compatible + openai_schema = accessor._prepare_schema_for_openai(schema) + assert openai_schema["type"] == "object" + assert "oneOf" not in str(openai_schema) + assert "anyOf" not in str(openai_schema) + assert openai_schema["required"] == ["type"] + + # Should have all possible properties from all union members + properties = openai_schema["properties"] + expected_props = ["type", "content", "artifacts", "subtasks", "follow_up_ask", "error_message", "retryable"] + for prop in expected_props: + assert prop in properties, f"Missing property: {prop}" + + # Type field should have all discriminator values + type_prop = properties["type"] + assert type_prop["type"] == "string" + expected_types = {"decomposed", "implemented", "follow_up_required", "failed"} + assert set(type_prop["enum"]) == expected_types + + # Should preserve complex type references + assert "$defs" in openai_schema + assert "Task" in openai_schema["$defs"]