diff --git a/cog_safe_push/schema.py b/cog_safe_push/schema.py index 0bcd722..703c9c5 100644 --- a/cog_safe_push/schema.py +++ b/cog_safe_push/schema.py @@ -36,8 +36,8 @@ def check_backwards_compatible( input_name = "TrainingInput" if train else "Input" output_name = "TrainingOutput" if train else "Output" - test_inputs = test_model_schemas[input_name] - inputs = model_schemas[input_name] + test_inputs = test_model_schemas[input_name]["properties"] + inputs = model_schemas[input_name]["properties"] errors = [] for name, spec in inputs.items(): diff --git a/test/test_schema.py b/test/test_schema.py index 037d645..f30eb8b 100644 --- a/test/test_schema.py +++ b/test/test_schema.py @@ -10,21 +10,33 @@ ) +def make_input_schema(properties: dict) -> dict: + """Helper to create a properly structured Input schema.""" + return {"type": "object", "title": "Input", "properties": properties} + + def test_identical_schemas(): old = new = { - "Input": {"text": {"type": "string"}, "number": {"type": "integer"}}, + "Input": make_input_schema( + {"text": {"type": "string"}, "number": {"type": "integer"}} + ), "Output": {"type": "string"}, } check_backwards_compatible(new, old, train=False) # Should not raise def test_new_optional_input(): - old = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}} + old = { + "Input": make_input_schema({"text": {"type": "string"}}), + "Output": {"type": "string"}, + } new = { - "Input": { - "text": {"type": "string"}, - "optional": {"type": "string", "default": "value"}, - }, + "Input": make_input_schema( + { + "text": {"type": "string"}, + "optional": {"type": "string", "default": "value"}, + } + ), "Output": {"type": "string"}, } check_backwards_compatible(new, old, train=False) # Should not raise @@ -32,17 +44,28 @@ def test_new_optional_input(): def test_removed_input(): old = { - "Input": {"text": {"type": "string"}, "number": {"type": "integer"}}, + "Input": make_input_schema( + {"text": {"type": "string"}, "number": {"type": "integer"}} + ), + "Output": {"type": "string"}, + } + new = { + "Input": make_input_schema({"text": {"type": "string"}}), "Output": {"type": "string"}, } - new = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}} with pytest.raises(IncompatibleSchemaError, match="Missing input number"): check_backwards_compatible(new, old, train=False) def test_changed_input_type(): - old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}} - new = {"Input": {"value": {"type": "string"}}, "Output": {"type": "string"}} + old = { + "Input": make_input_schema({"value": {"type": "integer"}}), + "Output": {"type": "string"}, + } + new = { + "Input": make_input_schema({"value": {"type": "string"}}), + "Output": {"type": "string"}, + } with pytest.raises( IncompatibleSchemaError, match="Input value has changed type from integer to string", @@ -51,9 +74,12 @@ def test_changed_input_type(): def test_added_minimum_constraint(): - old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}} + old = { + "Input": make_input_schema({"value": {"type": "integer"}}), + "Output": {"type": "string"}, + } new = { - "Input": {"value": {"type": "integer", "minimum": 0}}, + "Input": make_input_schema({"value": {"type": "integer", "minimum": 0}}), "Output": {"type": "string"}, } with pytest.raises( @@ -64,11 +90,11 @@ def test_added_minimum_constraint(): def test_increased_minimum(): old = { - "Input": {"value": {"type": "integer", "minimum": 0}}, + "Input": make_input_schema({"value": {"type": "integer", "minimum": 0}}), "Output": {"type": "string"}, } new = { - "Input": {"value": {"type": "integer", "minimum": 1}}, + "Input": make_input_schema({"value": {"type": "integer", "minimum": 1}}), "Output": {"type": "string"}, } with pytest.raises( @@ -78,9 +104,12 @@ def test_increased_minimum(): def test_added_maximum_constraint(): - old = {"Input": {"value": {"type": "integer"}}, "Output": {"type": "string"}} + old = { + "Input": make_input_schema({"value": {"type": "integer"}}), + "Output": {"type": "string"}, + } new = { - "Input": {"value": {"type": "integer", "maximum": 100}}, + "Input": make_input_schema({"value": {"type": "integer", "maximum": 100}}), "Output": {"type": "string"}, } with pytest.raises( @@ -91,11 +120,11 @@ def test_added_maximum_constraint(): def test_decreased_maximum(): old = { - "Input": {"value": {"type": "integer", "maximum": 100}}, + "Input": make_input_schema({"value": {"type": "integer", "maximum": 100}}), "Output": {"type": "string"}, } new = { - "Input": {"value": {"type": "integer", "maximum": 99}}, + "Input": make_input_schema({"value": {"type": "integer", "maximum": 99}}), "Output": {"type": "string"}, } with pytest.raises( @@ -106,12 +135,16 @@ def test_decreased_maximum(): def test_changed_choice_type(): old = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "integer", "enum": [1, 2, 3]}, "Output": {"type": "string"}, } @@ -124,12 +157,16 @@ def test_changed_choice_type(): def test_added_choice(): old = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "string", "enum": ["A", "B", "C", "D"]}, "Output": {"type": "string"}, } @@ -138,12 +175,16 @@ def test_added_choice(): def test_removed_choice(): old = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { - "Input": {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}}, + "Input": make_input_schema( + {"choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}} + ), "choice": {"type": "string", "enum": ["A", "B"]}, "Output": {"type": "string"}, } @@ -154,9 +195,14 @@ def test_removed_choice(): def test_new_required_input(): - old = {"Input": {"text": {"type": "string"}}, "Output": {"type": "string"}} + old = { + "Input": make_input_schema({"text": {"type": "string"}}), + "Output": {"type": "string"}, + } new = { - "Input": {"text": {"type": "string"}, "new_required": {"type": "string"}}, + "Input": make_input_schema( + {"text": {"type": "string"}, "new_required": {"type": "string"}} + ), "Output": {"type": "string"}, } with pytest.raises( @@ -166,29 +212,39 @@ def test_new_required_input(): def test_changed_output_type(): - old = {"Input": {}, "Output": {"type": "string"}} - new = {"Input": {}, "Output": {"type": "integer"}} + old = { + "Input": make_input_schema({}), + "Output": {"type": "string"}, + } + new = { + "Input": make_input_schema({}), + "Output": {"type": "integer"}, + } with pytest.raises(IncompatibleSchemaError, match="Output has changed type"): check_backwards_compatible(new, old, train=False) def test_multiple_incompatibilities(): old = { - "Input": { - "text": {"type": "string"}, - "number": {"type": "integer", "minimum": 0}, - "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, - }, + "Input": make_input_schema( + { + "text": {"type": "string"}, + "number": {"type": "integer", "minimum": 0}, + "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, + } + ), "choice": {"type": "string", "enum": ["A", "B", "C"]}, "Output": {"type": "string"}, } new = { - "Input": { - "text": {"type": "integer"}, - "number": {"type": "integer", "minimum": 1}, - "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, - "new_required": {"type": "string"}, - }, + "Input": make_input_schema( + { + "text": {"type": "integer"}, + "number": {"type": "integer", "minimum": 1}, + "choice": {"allOf": [{"$ref": "#/components/schemas/choice"}]}, + "new_required": {"type": "string"}, + } + ), "choice": {"type": "string", "enum": ["A", "B"]}, "Output": {"type": "integer"}, } @@ -202,6 +258,137 @@ def test_multiple_incompatibilities(): assert "Output has changed type" in error_message +def test_training_input_schema(): + """Test that train=True uses TrainingInput instead of Input.""" + old = { + "TrainingInput": make_input_schema({"data": {"type": "string"}}), + "TrainingOutput": {"type": "string"}, + } + new = { + "TrainingInput": make_input_schema({"data": {"type": "integer"}}), + "TrainingOutput": {"type": "string"}, + } + with pytest.raises( + IncompatibleSchemaError, + match="TrainingInput data has changed type from string to integer", + ): + check_backwards_compatible(new, old, train=True) + + +def test_realistic_schema_structure(): + """Test with a realistic schema structure from get_schemas(). + + This guards against the bug where check_backwards_compatible expected + Input to be the properties dict directly, but get_schemas returns: + {"Input": {"type": "object", "properties": {...}}} + """ + old = { + "Input": { + "type": "object", + "title": "Input", + "properties": { + "seed": { + "type": "integer", + "title": "Seed", + "x-order": 8, + "nullable": True, + "description": "Random seed. Set for reproducible generation", + }, + "image": { + "type": "string", + "title": "Image", + "format": "uri", + "x-order": 0, + "nullable": True, + "description": "Image file", + }, + "prompt": { + "type": "string", + "title": "Prompt", + "x-order": 6, + "nullable": True, + "description": "Text prompt for image generation", + }, + }, + }, + "Output": {"type": "string"}, + } + new = { + "Input": { + "type": "object", + "title": "Input", + "properties": { + "seed": { + "type": "integer", + "title": "Seed", + "x-order": 8, + "nullable": True, + "description": "Random seed. Set for reproducible generation", + }, + "image": { + "type": "string", + "title": "Image", + "format": "uri", + "x-order": 0, + "nullable": True, + "description": "Image file", + }, + "prompt": { + "type": "string", + "title": "Prompt", + "x-order": 6, + "nullable": True, + "description": "Text prompt for image generation", + }, + }, + }, + "Output": {"type": "string"}, + } + check_backwards_compatible(new, old, train=False) # Should not raise + + +def test_realistic_schema_structure_with_incompatibility(): + """Test realistic schema detects incompatibilities.""" + old = { + "Input": { + "type": "object", + "title": "Input", + "properties": { + "steps": { + "type": "integer", + "title": "Steps", + "minimum": 1, + "maximum": 100, + "default": 25, + "x-order": 0, + "description": "Number of steps", + }, + }, + }, + "Output": {"type": "string"}, + } + new = { + "Input": { + "type": "object", + "title": "Input", + "properties": { + "steps": { + "type": "integer", + "title": "Steps", + "minimum": 10, # Higher minimum - breaking change + "maximum": 100, + "default": 25, + "x-order": 0, + "description": "Number of steps", + }, + }, + }, + "Output": {"type": "string"}, + } + with pytest.raises(IncompatibleSchemaError, match="steps has a higher minimum"): + check_backwards_compatible(new, old, train=False) + + def test_lint_deprecated_input_without_description(): mock_model = Mock() mock_model.versions.list.return_value = [