diff --git a/stubs/peewee/peewee.pyi b/stubs/peewee/peewee.pyi index 816b7a7e8d0c..950cc1153bf5 100644 --- a/stubs/peewee/peewee.pyi +++ b/stubs/peewee/peewee.pyi @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator from datetime import datetime from decimal import Decimal from types import TracebackType -from typing import Any, ClassVar, Final, Literal, NamedTuple, NoReturn, TypeVar, overload, type_check_only +from typing import Any, ClassVar, Final, Literal, NamedTuple, NoReturn, TypeAlias, TypeVar, overload, type_check_only from typing_extensions import Self, TypeIs from uuid import UUID @@ -18,6 +18,7 @@ def reraise(tp: Unused, value: BaseException, tb: TracebackType | None = None) - _T = TypeVar("_T") _VT = TypeVar("_VT") _F = TypeVar("_F", bound=Callable[..., Any]) +_DatabaseType: TypeAlias = Database | DatabaseProxy class attrdict(dict[str, _VT]): def __getattr__(self, attr: str) -> _VT: ... @@ -192,7 +193,7 @@ class BaseTable(Source): class _BoundTableContext(_callable_context_manager): table: Incomplete database: Incomplete - def __init__(self, table, database) -> None: ... + def __init__(self, table, database: _DatabaseType) -> None: ... def __enter__(self): ... def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None @@ -206,8 +207,8 @@ class Table(_HashableSource, BaseTable): # type: ignore[misc] self, name, columns=None, primary_key=None, schema: str | None = None, alias=None, _model=None, _database=None ) -> None: ... def clone(self) -> Table: ... - def bind(self, database=None) -> Self: ... - def bind_ctx(self, database=None) -> _BoundTableContext: ... + def bind(self, database: _DatabaseType | None = None) -> Self: ... + def bind_ctx(self, database: _DatabaseType | None = None) -> _BoundTableContext: ... def select(self, *columns) -> Select: ... def insert(self, insert=None, columns=None, **kwargs) -> Insert: ... def replace(self, insert=None, columns=None, **kwargs): ... @@ -529,7 +530,7 @@ class OnConflict(Node): class BaseQuery(Node): default_row_type: Incomplete def __init__(self, _database=None, **kwargs) -> None: ... - def bind(self, database=None) -> Self: ... + def bind(self, database: _DatabaseType | None = None) -> Self: ... def clone(self) -> Self: ... def dicts(self, as_dict: bool = True) -> Self: ... def tuples(self, as_tuple: bool = True) -> Self: ... @@ -537,8 +538,8 @@ class BaseQuery(Node): def objects(self, constructor=None) -> Self: ... def __sql__(self, ctx) -> None: ... def sql(self): ... - def execute(self, database=None): ... - def iterator(self, database=None): ... + def execute(self, database: _DatabaseType | None = None): ... + def iterator(self, database: _DatabaseType | None = None): ... def __iter__(self): ... def __getitem__(self, value): ... def __len__(self) -> int: ... @@ -575,20 +576,20 @@ class SelectQuery(Query): def select_from(self, *columns) -> Select: ... class SelectBase(_HashableSource, Source, SelectQuery): # type: ignore[misc] - def peek(self, database=None, n: int = 1): ... - def first(self, database=None, n: int = 1): ... - def scalar(self, database=None, as_tuple: bool = False, as_dict: bool = False): ... - def scalars(self, database=None) -> Generator[Incomplete]: ... - def count(self, database=None, clear_limit: bool = False) -> int: ... - def exists(self, database=None) -> bool: ... - def get(self, database=None): ... + def peek(self, database: _DatabaseType | None = None, n: int = 1): ... + def first(self, database: _DatabaseType | None = None, n: int = 1): ... + def scalar(self, database: _DatabaseType | None = None, as_tuple: bool = False, as_dict: bool = False): ... + def scalars(self, database: _DatabaseType | None = None) -> Generator[Incomplete]: ... + def count(self, database: _DatabaseType | None = None, clear_limit: bool = False) -> int: ... + def exists(self, database: _DatabaseType | None = None) -> bool: ... + def get(self, database: _DatabaseType | None = None): ... class CompoundSelectQuery(SelectBase): lhs: Incomplete op: Incomplete rhs: Incomplete def __init__(self, lhs, op, rhs) -> None: ... - def exists(self, database=None) -> bool: ... + def exists(self, database: _DatabaseType | None = None) -> bool: ... def __sql__(self, ctx): ... class Select(SelectBase): @@ -635,8 +636,8 @@ class _WriteQuery(Query): def cte(self, name, recursive: bool = False, columns=None, materialized=None) -> CTE: ... def returning(self, *returning) -> Self: ... def apply_returning(self, ctx): ... - def execute_returning(self, database): ... - def handle_result(self, database, cursor): ... + def execute_returning(self, database: _DatabaseType): ... + def handle_result(self, database: _DatabaseType, cursor): ... def __sql__(self, ctx): ... class Update(_WriteQuery): @@ -660,7 +661,7 @@ class Insert(_WriteQuery): def get_default_data(self): ... def get_default_columns(self) -> list[Incomplete] | None: ... def __sql__(self, ctx): ... - def handle_result(self, database, cursor): ... + def handle_result(self, database: _DatabaseType, cursor): ... class Delete(_WriteQuery): def __sql__(self, ctx): ... @@ -778,7 +779,7 @@ class Database(_callable_context_manager): connect_params: Incomplete def __init__( self, - database, + database: str | None, thread_safe: bool = True, autorollback: bool = False, field_types=None, @@ -789,7 +790,7 @@ class Database(_callable_context_manager): ) -> None: ... database: Incomplete deferred: Incomplete - def init(self, database, **kwargs) -> None: ... + def init(self, database: str, **kwargs) -> None: ... def __enter__(self) -> Self: ... def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None @@ -857,10 +858,10 @@ class SqliteDatabase(Database): truncate_table: bool nulls_ordering: bool def __init__( - self, database, pragmas=None, regexp_function: bool = False, rank_functions: bool = False, *args, **kwargs + self, database: str | None, pragmas=None, regexp_function: bool = False, rank_functions: bool = False, *args, **kwargs ) -> None: ... returning_clause: Incomplete - def init(self, database, pragmas=None, timeout: int = 5, returning_clause=None, **kwargs) -> None: ... + def init(self, database: str, pragmas=None, timeout: int = 5, returning_clause=None, **kwargs) -> None: ... def pragma(self, key, value=..., permanent: bool = False, schema: str | None = None): ... cache_size: Incomplete foreign_keys: Incomplete @@ -968,7 +969,7 @@ class PostgresqlDatabase(Database): psycopg3_adapter: Incomplete def init( self, - database, + database: str | None, register_unicode: bool = True, encoding=None, isolation_level=None, @@ -1012,7 +1013,7 @@ class MySQLDatabase(Database): safe_create_index: bool safe_drop_index: bool sql_mode: str - def init(self, database, **kwargs) -> None: ... + def init(self, database: str | None, **kwargs) -> None: ... def is_connection_usable(self) -> bool: ... def default_values_insert(self, ctx): ... def begin(self, isolation_level: str | None = None) -> None: ... @@ -1515,7 +1516,7 @@ class _SortedFieldList: class SchemaManager: model: Incomplete context_options: Incomplete - def __init__(self, model, database=None, **context_options) -> None: ... + def __init__(self, model, database: _DatabaseType | None = None, **context_options) -> None: ... @property def database(self): ... @@ -1569,7 +1570,7 @@ class Metadata: def __init__( self, model, - database=None, + database: _DatabaseType | None = None, table_name=None, indexes=None, primary_key=None, @@ -1616,7 +1617,7 @@ class Metadata: def get_primary_keys(self): ... def get_default_dict(self): ... def fields_to_index(self) -> list[Incomplete]: ... - def set_database(self, database) -> None: ... + def set_database(self, database: _DatabaseType) -> None: ... def set_table_name(self, table_name) -> None: ... class SubclassAwareMetadata(Metadata): @@ -1644,7 +1645,7 @@ class _BoundModelsContext(_callable_context_manager): database: Incomplete bind_refs: Incomplete bind_backrefs: Incomplete - def __init__(self, models, database, bind_refs, bind_backrefs) -> None: ... + def __init__(self, models, database: _DatabaseType, bind_refs, bind_backrefs) -> None: ... def __enter__(self): ... def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None @@ -1713,9 +1714,9 @@ class Model(metaclass=ModelBase): def __ne__(self, other) -> Expression | bool: ... # type: ignore[override] def __sql__(self, ctx): ... @classmethod - def bind(cls, database, bind_refs: bool = True, bind_backrefs: bool = True, _exclude=None) -> bool: ... + def bind(cls, database: _DatabaseType, bind_refs: bool = True, bind_backrefs: bool = True, _exclude=None) -> bool: ... @classmethod - def bind_ctx(cls, database, bind_refs: bool = True, bind_backrefs: bool = True) -> _BoundModelsContext: ... + def bind_ctx(cls, database: _DatabaseType, bind_refs: bool = True, bind_backrefs: bool = True) -> _BoundModelsContext: ... @classmethod def table_exists(cls): ... @classmethod @@ -1774,8 +1775,8 @@ class BaseModelSelect(_ModelQueryHelper): __sub__ = except_ def __iter__(self): ... def prefetch(self, *subqueries): ... - def get(self, database=None): ... - def get_or_none(self, database=None): ... + def get(self, database: _DatabaseType | None = None): ... + def get_or_none(self, database: _DatabaseType | None = None): ... def group_by(self, *columns) -> Self: ... class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): # type: ignore[misc]