Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion flask_openapi3/models/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Time : 2023/7/4 9:55
from typing import Any, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_serializer

from .data_type import DataType
from .discriminator import Discriminator
Expand Down Expand Up @@ -57,3 +57,11 @@ class Schema(BaseModel):
const: Any | None = None

model_config = {"populate_by_name": True}

@model_serializer(mode="wrap", when_used="json")
def _serialize(self, serializer, info):
data = serializer(self)
# Remove 'default' key if it's None to maintain OpenAPI spec compliance
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This question was created in #189, @ddorian do you remember why you did it at the time?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(IIRC) There are cases when "default=None" is needed to be documented. Like, the field requires a value, None is valid, and it's the default. I already use openapi-spec-validator and don't see what validation it breaks though from a quick check of the pull request.

if data.get("default") is None:
data.pop("default", None)
return data
3 changes: 2 additions & 1 deletion flask_openapi3/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def generate_spec_json(self):

# Set components
self.components.schemas = self.components_schemas
self.components.securitySchemes = self.security_schemes
if self.security_schemes:
self.components.securitySchemes = self.security_schemes
self.spec.components = self.components

# Convert spec to JSON
Expand Down
35 changes: 33 additions & 2 deletions flask_openapi3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,24 @@ def parse_body(
return content, components_schemas


def validate_links(links: dict) -> dict:
"""
Validates links and returns only valid ones.
Each link must have either operationRef or operationId.

Args:
links: Dictionary of link objects to validate.

Returns:
Dictionary containing only valid links.
"""
valid_links = {}
for link_name, link_obj in links.items():
if isinstance(link_obj, dict) and ("operationRef" in link_obj or "operationId" in link_obj):
valid_links[link_name] = link_obj
return valid_links


def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, operation: Operation) -> None:
_responses = {}
_schemas = {}
Expand All @@ -323,6 +341,14 @@ def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, opera
_responses[key] = Response(description=HTTP_STATUS.get(key, ""))
elif isinstance(response, dict):
response["description"] = response.get("description", HTTP_STATUS.get(key, ""))
# Validate links - each link must have either operationRef or operationId
if "links" in response:
valid_links = validate_links(response.get("links", {}))
# Only set links if there are valid ones, otherwise remove the key
if valid_links:
response["links"] = valid_links
else:
response.pop("links", None)
_responses[key] = Response(**response)
else:
# OpenAPI 3 support ^[a-zA-Z0-9\.\-_]+$ so we should normalize __name__
Expand All @@ -344,7 +370,11 @@ def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, opera
if "headers" in openapi_extra_keys:
_responses[key].headers = openapi_extra.get("headers")
if "links" in openapi_extra_keys:
_responses[key].links = openapi_extra.get("links")
# Validate links - each link must have either operationRef or operationId
valid_links = validate_links(openapi_extra.get("links", {}))
# Only set links if there are valid ones
if valid_links:
_responses[key].links = valid_links
_content = _responses[key].content
if "example" in openapi_extra_keys:
_content["application/json"].example = openapi_extra.get("example") # type: ignore
Expand All @@ -362,7 +392,8 @@ def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, opera
_schemas[normalize_name(name)] = Schema(**value)

components_schemas.update(**_schemas)
operation.responses = _responses
if _responses:
operation.responses = _responses


def parse_and_store_tags(
Expand Down
37 changes: 37 additions & 0 deletions flask_openapi3/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def wrapper(cls):
name=cls_method.__qualname__, path=rule, method=method
)

# Ensure all path parameters in URI are defined in operation parameters
self._ensure_path_parameters(uri, cls_method.operation)

# Convert route parameters from {param} to <param>
_rule = uri.replace("{", "<").replace("}", ">")
self.views[_rule] = (cls, methods)
Expand All @@ -101,6 +104,40 @@ def wrapper(cls):

return wrapper

def _ensure_path_parameters(self, uri: str, operation):
"""Ensure all path parameters in the URI are defined in the operation's parameters."""
import re

from .models import Parameter, ParameterInType, Schema
from .models.data_type import DataType

# Extract all path parameter names from URI
param_pattern = r"\{([^}]+)\}"
uri_params = set(re.findall(param_pattern, uri))

if not uri_params:
return

# Get existing path parameter names
existing_params = set()
if operation.parameters:
existing_params = {p.name for p in operation.parameters if p.param_in == ParameterInType.PATH}

# Add missing path parameters
missing_params = uri_params - existing_params
if missing_params:
if not operation.parameters:
operation.parameters = []
for param_name in sorted(missing_params): # Sort for consistent ordering
operation.parameters.append(
Parameter(
name=param_name,
param_in=ParameterInType.PATH,
required=True,
param_schema=Schema(type=DataType.STRING),
)
)

def doc(
self,
*,
Expand Down
17 changes: 8 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ doc = [
]
mypy = ["mypy"]
ruff = ["ruff"]
test = ["pytest", "asgiref", "pyyaml"]
test = ["pytest", "asgiref", "pyyaml", "openapi-spec-validator"]

[build-system]
requires = ["hatchling"]
Expand All @@ -75,20 +75,20 @@ include = [
"/tests",
"/CHANGELOG.md",
"/CONTRIBUTING.md",
"/LICENSE.rst"
"/LICENSE.rst",
]

[tool.ruff]
line-length = 120

[tool.ruff.lint]
select = [
"I", # Import related rules
"E4", # Import order
"E7", # Statement issues
"E9", # Runtime errors
"F", # Pyflakes rules
"Q", # flake8-quotes
"I", # Import related rules
"E4", # Import order
"E7", # Statement issues
"E9", # Runtime errors
"F", # Pyflakes rules
"Q", # flake8-quotes
"UP045", # Use X | None for type annotations
]

Expand All @@ -104,4 +104,3 @@ plugins = ["pydantic.mypy"]
module = ["pydantic_core.*", "devtools.*"]
follow_imports = "skip"
ignore_missing_imports = true

5 changes: 5 additions & 0 deletions tests/test_api_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import APIBlueprint, Info, OpenAPI, Tag
Expand Down Expand Up @@ -112,6 +113,10 @@ def get_book(path: BookPath):
def test_openapi(client):
resp = client.get("/openapi/openapi.json")
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(resp.json)

assert resp.json == app.api_doc
assert resp.json["paths"]["/api/book/{bid}"]["put"]["operationId"] == "update"
assert resp.json["paths"]["/api/book/{bid}"]["delete"]["operationId"] == "delete_book"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_api_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import APIView, Info, OpenAPI, Tag
Expand Down Expand Up @@ -96,6 +97,10 @@ def client():
def test_openapi(client):
resp = client.get("/openapi/openapi.json")
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(resp.json)

assert resp.json == app.api_doc
assert resp.json["paths"]["/api/v1/{name}/book/{id}"]["put"]["operationId"] == "update"
assert (
Expand Down
5 changes: 5 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@


import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import APIView, OpenAPI
Expand Down Expand Up @@ -63,6 +64,10 @@ def client():
def test_openapi(client):
resp = client.get("/openapi/openapi.json")
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(resp.json)

assert resp.json == app.api_doc


Expand Down
5 changes: 5 additions & 0 deletions tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from enum import Enum

import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import Info, OpenAPI
Expand Down Expand Up @@ -40,6 +41,10 @@ def test_openapi(client):
resp = client.get("/openapi/openapi.json")
_json = resp.json
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(_json)

assert _json["components"]["schemas"].get("Language") is not None

resp = client.get("/English")
Expand Down
4 changes: 4 additions & 0 deletions tests/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# @Author : llc
# @Time : 2023/6/30 10:12
import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import FileStorage, OpenAPI
Expand Down Expand Up @@ -89,6 +90,9 @@ def test_openapi(client):
resp = client.get("/openapi/openapi.json")
_json = resp.json
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(_json)
assert _json["paths"]["/form"]["post"]["requestBody"]["content"]["multipart/form-data"]["examples"] == {
"Example 01": {"summary": "An example", "value": {"file": "Example-01.jpg", "str_list": ["a", "b", "c"]}},
"Example 02": {"summary": "Another example", "value": {"str_list": ["1", "2", "3"]}},
Expand Down
5 changes: 5 additions & 0 deletions tests/test_nested_apiblueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @Time : 2022/4/2 9:09

import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel

from flask_openapi3 import APIBlueprint, OpenAPI, Tag
Expand Down Expand Up @@ -45,6 +46,10 @@ def client():
def test_openapi(client):
resp = client.get("/openapi/openapi.json")
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(resp.json)

assert resp.json == app.api_doc


Expand Down
5 changes: 5 additions & 0 deletions tests/test_number_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# @Time : 2024/4/19 20:53

import pytest
from openapi_spec_validator import validate
from pydantic import BaseModel, Field

from flask_openapi3 import OpenAPI
Expand Down Expand Up @@ -30,6 +31,10 @@ def client():
def test_openapi(client):
resp = client.get("/openapi/openapi.json")
assert resp.status_code == 200

# Validate the spec against OpenAPI specification
validate(resp.json)

assert resp.json == app.api_doc

model_props = resp.json["components"]["schemas"]["MyModel"]["properties"]
Expand Down
Loading