Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Checklist:
- [ ] Run `pytest tests` and no failed.
- [ ] Run `ruff check star_openapi tests examples` and no failed.
- [ ] Run `ruff format star_openapi tests examples` and no failed.
- [ ] Run `mypy star_openapi` and no failed.
- [ ] Run `ty check star_openapi` and no failed.
- [ ] Run `mkdocs serve` and no failed.
10 changes: 2 additions & 8 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,5 @@ jobs:
- name: ruff format
run: uv run ruff format --check star_openapi tests examples

- name: Cache mypy
uses: actions/cache@v5
with:
path: ./.mypy_cache
key: mypy|${{ matrix.python-version }}|${{ hashFiles('pyproject.toml') }}

- name: Run mypy
run: uv run mypy star_openapi
- name: Run ty
run: uv run ty check star_openapi
7 changes: 4 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@ Before submitting pr, you need to complete the following steps:
ruff check star_openapi tests examples
```

4. Running the mypy
4. Running the ty

```bash
mypy star_openapi
ty check star_openapi
```

5. Building the docs

Serve the live docs with [Material for MkDocs](https://github.com/squidfunk/mkdocs-material), and make sure it's correct.
Serve the live docs with [Material for MkDocs](https://github.com/squidfunk/mkdocs-material), and make sure it's
correct.

```bash
mkdocs serve
Expand Down
10 changes: 1 addition & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ build-backend = "hatchling.build"
path = "star_openapi/__version__.py"

[dependency-groups]
mypy = ["mypy"]
ty = ["ty"]
ruff = ["ruff"]
test = ["pytest", "httpx"]
doc = [
Expand Down Expand Up @@ -109,11 +109,3 @@ select = [
[tool.ruff.lint.per-file-ignores]
"star_openapi/__init__.py" = ["F401"]
"star_openapi/models/__init__.py" = ["F401"]

[tool.mypy]
plugins = ["pydantic.mypy"]

[[tool.mypy.overrides]]
module = ["pydantic_core.*", "devtools.*"]
follow_imports = "skip"
ignore_missing_imports = true
4 changes: 2 additions & 2 deletions star_openapi/models/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Header(Parameter):
https://spec.openapis.org/oas/v3.1.0#header-object
"""

name: str | None = None # type:ignore
param_in: ParameterInType | None = None # type:ignore
name: str | None = None
param_in: ParameterInType | None = None

model_config = {"extra": "allow"}
28 changes: 15 additions & 13 deletions star_openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from http import HTTPMethod
from importlib import import_module
from importlib.metadata import entry_points
from types import FunctionType
from typing import Any, Type

from jinja2 import Template
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(
self.components = Components()

# Initialize lists for tags and tag names
self.tags: list[Tag | dict[str, Any]] = []
self.tags: list[Tag] = []
self.tag_names: list[str] = []

# Set URL prefixes and endpoints
Expand Down Expand Up @@ -141,7 +142,6 @@ def __init__(

# Initialize specification JSON
self.spec_json: dict[str, Any] = {}
self.spec = OpenAPISpec(openapi=self.openapi_version, info=self.info, paths=self.paths)

self.cli = cli

Expand Down Expand Up @@ -197,19 +197,23 @@ def api_doc(self) -> dict:
return self.spec_json

def generate_spec_json(self):
self.spec.openapi = self.openapi_version
self.spec.info = self.info
self.spec.paths = self.paths
if isinstance(self.info, dict):
self.info = Info.model_validate(self.info)
spec = OpenAPISpec(openapi=self.openapi_version, info=self.info, paths=self.paths)
spec.openapi = self.openapi_version
spec.info = self.info

if self.severs:
self.spec.servers = self.severs
spec.servers = [Server(**server) if isinstance(server, dict) else server for server in self.severs]

if self.external_docs:
self.spec.externalDocs = self.external_docs
if isinstance(self.external_docs, dict):
self.external_docs = ExternalDocumentation.model_validate(self.external_docs)
spec.externalDocs = self.external_docs

# Set tags
if self.tags:
self.spec.tags = self.tags
spec.tags = self.tags

# Add ValidationErrorModel to components schemas
schema = get_model_schema(self.validation_error_model)
Expand All @@ -223,10 +227,10 @@ def generate_spec_json(self):
# Set components
self.components.schemas = self.components_schemas
self.components.securitySchemes = self.security_schemes
self.spec.components = self.components
spec.components = self.components

# Convert spec to JSON
self.spec_json = self.spec.model_dump(mode="json", by_alias=True, exclude_unset=True, warnings=False)
self.spec_json = spec.model_dump(mode="json", by_alias=True, exclude_unset=True, warnings=False)

# Update with OpenAPI extensions
self.spec_json.update(**self.openapi_extensions)
Expand All @@ -252,8 +256,6 @@ def generate_spec_json(self):

def register_api(self, api: APIRouter):
for tag in api.tags:
if isinstance(tag, dict):
tag = Tag(**tag)
if tag.name not in self.tag_names:
# Append tag to the list of tags
self.tags.append(tag)
Expand Down Expand Up @@ -283,7 +285,7 @@ def register_api(self, api: APIRouter):
def _collect_openapi_info(
self,
rule: str,
func: Callable,
func: FunctionType,
*,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
Expand Down
2 changes: 1 addition & 1 deletion star_openapi/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def _validate_header(request: Request, header: Type[BaseModel]):
if value is not None:
header_dict[key] = value
if model_field_schema.get("type") == "null":
header_dict[key] = value # type:ignore
header_dict[key] = value
# extra keys
for key, value in request_headers.items():
if key not in header_dict.keys():
Expand Down
10 changes: 5 additions & 5 deletions star_openapi/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from http import HTTPMethod
from types import FunctionType
from typing import Any

from starlette.routing import Route, Router, WebSocketRoute
Expand Down Expand Up @@ -48,7 +49,8 @@ def __init__(
self.url_prefix = url_prefix
self.paths: dict[str, Any] = {}
self.components_schemas: dict[str, Any] = {}
self.tags = tags or []
self.tags: list[Tag] = []
self.api_tags = tags or []
self.tag_names: list[str] = []
self.security = security or []
self.operation_id_callback = operation_id_callback
Expand All @@ -57,8 +59,6 @@ def __init__(

def register_api(self, api: "APIRouter"):
for tag in api.tags:
if isinstance(tag, dict):
tag = Tag(**tag)
if tag.name not in self.tag_names:
# Append tag to the list of tags
self.tags.append(tag)
Expand Down Expand Up @@ -88,7 +88,7 @@ def register_api(self, api: "APIRouter"):
def _collect_openapi_info(
self,
rule: str,
func: Callable,
func: FunctionType,
*,
tags: list[Tag | dict[str, Any]] | None = None,
summary: str | None = None,
Expand Down Expand Up @@ -142,7 +142,7 @@ def _collect_openapi_info(
operation.servers = servers

# Store tags
tags = (tags or []) + self.tags
tags = (tags or []) + self.api_tags
parse_and_store_tags(tags, self.tags, self.tag_names, operation)

# Parse rule: merge url_prefix
Expand Down
30 changes: 13 additions & 17 deletions star_openapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def get_operation(
operation_dict = {}

if summary:
operation_dict["summary"] = summary # type: ignore
operation_dict["summary"] = summary

if description:
operation_dict["description"] = description # type: ignore
operation_dict["description"] = description

# Add any additional openapi_extensions to the operation dictionary
operation_dict.update(openapi_extensions or {})
Expand Down Expand Up @@ -135,7 +135,7 @@ def parse_header(header: Type[BaseModel]) -> tuple[list[Parameter], dict]:
data["example"] = value.get("example")
if "examples" in value.keys():
data["examples"] = value.get("examples")
parameters.append(Parameter(**data))
parameters.append(Parameter.model_validate(data))

# Parse definitions
definitions = schema.get("$defs", {})
Expand Down Expand Up @@ -168,7 +168,7 @@ def parse_cookie(cookie: Type[BaseModel]) -> tuple[list[Parameter], dict]:
data["example"] = value.get("example")
if "examples" in value.keys():
data["examples"] = value.get("examples")
parameters.append(Parameter(**data))
parameters.append(Parameter.model_validate(data))

# Parse definitions
definitions = schema.get("$defs", {})
Expand Down Expand Up @@ -196,7 +196,7 @@ def parse_path(path: Type[BaseModel]) -> tuple[list[Parameter], dict]:
data["example"] = value.get("example")
if "examples" in value.keys():
data["examples"] = value.get("examples")
parameters.append(Parameter(**data))
parameters.append(Parameter.model_validate(data))

# Parse definitions
definitions = schema.get("$defs", {})
Expand Down Expand Up @@ -229,7 +229,7 @@ def parse_query(query: Type[BaseModel]) -> tuple[list[Parameter], dict]:
data["example"] = value.get("example")
if "examples" in value.keys():
data["examples"] = value.get("examples")
parameters.append(Parameter(**data))
parameters.append(Parameter.model_validate(data))

# Parse definitions
definitions = schema.get("$defs", {})
Expand All @@ -256,11 +256,7 @@ def parse_form(
for k, v in properties.items():
if v.get("type") == "array":
encoding[k] = Encoding(style="form", explode=True)
content = {
"multipart/form-data": MediaType(
schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}),
)
}
content = {"multipart/form-data": MediaType.model_validate({"schema": {"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}})}
if encoding:
content["multipart/form-data"].encoding = encoding

Expand All @@ -282,7 +278,7 @@ def parse_body(
original_title = schema.get("title") or body.__name__
title = normalize_name(original_title)
components_schemas[title] = Schema(**schema)
content = {"application/json": MediaType(schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}))}
content = {"application/json": MediaType.model_validate({"schema": {"$ref": f"{OPENAPI3_REF_PREFIX}/{title}"}})}

# Parse definitions
definitions = schema.get("$defs", {})
Expand All @@ -294,7 +290,7 @@ def parse_body(

def parse_and_store_tags(
new_tags: list[Tag | dict[str, Any]],
old_tags: list[Tag | dict[str, Any]],
old_tags: list[Tag],
old_tag_names: list[str],
operation: Operation,
) -> None:
Expand Down Expand Up @@ -344,10 +340,10 @@ def get_responses(responses: ResponseStrKeyDict, components_schemas: dict, opera
schema = get_model_schema(response, mode="serialization")
original_title = schema.get("title") or response.__name__
name = normalize_name(original_title)
_responses[key] = Response(
description=HTTP_STATUS.get(key, ""),
content={"application/json": MediaType(schema=Schema(**{"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"}))},
)
_responses[key] = Response(description=HTTP_STATUS.get(key, ""))
_responses[key].content = {
"application/json": MediaType.model_validate({"schema": {"$ref": f"{OPENAPI3_REF_PREFIX}/{name}"}})
}

_schemas[name] = Schema(**schema)
definitions = schema.get("$defs")
Expand Down
Loading