Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions stubs/peewee/peewee.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator, Iterable, Iterator
from datetime import datetime
from decimal import Decimal
from types import TracebackType
from typing import Any, ClassVar, Final, Literal, NamedTuple, NoReturn, TypeVar, overload, type_check_only
from typing import Any, ClassVar, Final, Literal, NamedTuple, NoReturn, TypeAlias, TypeVar, overload, type_check_only
from typing_extensions import Self, TypeIs
from uuid import UUID

Expand All @@ -18,6 +18,7 @@ def reraise(tp: Unused, value: BaseException, tb: TracebackType | None = None) -
_T = TypeVar("_T")
_VT = TypeVar("_VT")
_F = TypeVar("_F", bound=Callable[..., Any])
_DatabaseType: TypeAlias = Database | DatabaseProxy

class attrdict(dict[str, _VT]):
def __getattr__(self, attr: str) -> _VT: ...
Expand Down Expand Up @@ -192,7 +193,7 @@ class BaseTable(Source):
class _BoundTableContext(_callable_context_manager):
table: Incomplete
database: Incomplete
def __init__(self, table, database) -> None: ...
def __init__(self, table, database: _DatabaseType) -> None: ...
def __enter__(self): ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
Expand All @@ -206,8 +207,8 @@ class Table(_HashableSource, BaseTable): # type: ignore[misc]
self, name, columns=None, primary_key=None, schema: str | None = None, alias=None, _model=None, _database=None
) -> None: ...
def clone(self) -> Table: ...
def bind(self, database=None) -> Self: ...
def bind_ctx(self, database=None) -> _BoundTableContext: ...
def bind(self, database: _DatabaseType | None = None) -> Self: ...
def bind_ctx(self, database: _DatabaseType | None = None) -> _BoundTableContext: ...
def select(self, *columns) -> Select: ...
def insert(self, insert=None, columns=None, **kwargs) -> Insert: ...
def replace(self, insert=None, columns=None, **kwargs): ...
Expand Down Expand Up @@ -529,16 +530,16 @@ class OnConflict(Node):
class BaseQuery(Node):
default_row_type: Incomplete
def __init__(self, _database=None, **kwargs) -> None: ...
def bind(self, database=None) -> Self: ...
def bind(self, database: _DatabaseType | None = None) -> Self: ...
def clone(self) -> Self: ...
def dicts(self, as_dict: bool = True) -> Self: ...
def tuples(self, as_tuple: bool = True) -> Self: ...
def namedtuples(self, as_namedtuple: bool = True) -> Self: ...
def objects(self, constructor=None) -> Self: ...
def __sql__(self, ctx) -> None: ...
def sql(self): ...
def execute(self, database=None): ...
def iterator(self, database=None): ...
def execute(self, database: _DatabaseType | None = None): ...
def iterator(self, database: _DatabaseType | None = None): ...
def __iter__(self): ...
def __getitem__(self, value): ...
def __len__(self) -> int: ...
Expand Down Expand Up @@ -575,20 +576,20 @@ class SelectQuery(Query):
def select_from(self, *columns) -> Select: ...

class SelectBase(_HashableSource, Source, SelectQuery): # type: ignore[misc]
def peek(self, database=None, n: int = 1): ...
def first(self, database=None, n: int = 1): ...
def scalar(self, database=None, as_tuple: bool = False, as_dict: bool = False): ...
def scalars(self, database=None) -> Generator[Incomplete]: ...
def count(self, database=None, clear_limit: bool = False) -> int: ...
def exists(self, database=None) -> bool: ...
def get(self, database=None): ...
def peek(self, database: _DatabaseType | None = None, n: int = 1): ...
def first(self, database: _DatabaseType | None = None, n: int = 1): ...
def scalar(self, database: _DatabaseType | None = None, as_tuple: bool = False, as_dict: bool = False): ...
def scalars(self, database: _DatabaseType | None = None) -> Generator[Incomplete]: ...
def count(self, database: _DatabaseType | None = None, clear_limit: bool = False) -> int: ...
def exists(self, database: _DatabaseType | None = None) -> bool: ...
def get(self, database: _DatabaseType | None = None): ...

class CompoundSelectQuery(SelectBase):
lhs: Incomplete
op: Incomplete
rhs: Incomplete
def __init__(self, lhs, op, rhs) -> None: ...
def exists(self, database=None) -> bool: ...
def exists(self, database: _DatabaseType | None = None) -> bool: ...
def __sql__(self, ctx): ...

class Select(SelectBase):
Expand Down Expand Up @@ -635,8 +636,8 @@ class _WriteQuery(Query):
def cte(self, name, recursive: bool = False, columns=None, materialized=None) -> CTE: ...
def returning(self, *returning) -> Self: ...
def apply_returning(self, ctx): ...
def execute_returning(self, database): ...
def handle_result(self, database, cursor): ...
def execute_returning(self, database: _DatabaseType): ...
def handle_result(self, database: _DatabaseType, cursor): ...
def __sql__(self, ctx): ...

class Update(_WriteQuery):
Expand All @@ -660,7 +661,7 @@ class Insert(_WriteQuery):
def get_default_data(self): ...
def get_default_columns(self) -> list[Incomplete] | None: ...
def __sql__(self, ctx): ...
def handle_result(self, database, cursor): ...
def handle_result(self, database: _DatabaseType, cursor): ...

class Delete(_WriteQuery):
def __sql__(self, ctx): ...
Expand Down Expand Up @@ -778,7 +779,7 @@ class Database(_callable_context_manager):
connect_params: Incomplete
def __init__(
self,
database,
database: str | None,
thread_safe: bool = True,
autorollback: bool = False,
field_types=None,
Expand All @@ -789,7 +790,7 @@ class Database(_callable_context_manager):
) -> None: ...
database: Incomplete
deferred: Incomplete
def init(self, database, **kwargs) -> None: ...
def init(self, database: str, **kwargs) -> None: ...
def __enter__(self) -> Self: ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
Expand Down Expand Up @@ -857,10 +858,10 @@ class SqliteDatabase(Database):
truncate_table: bool
nulls_ordering: bool
def __init__(
self, database, pragmas=None, regexp_function: bool = False, rank_functions: bool = False, *args, **kwargs
self, database: str | None, pragmas=None, regexp_function: bool = False, rank_functions: bool = False, *args, **kwargs
) -> None: ...
returning_clause: Incomplete
def init(self, database, pragmas=None, timeout: int = 5, returning_clause=None, **kwargs) -> None: ...
def init(self, database: str, pragmas=None, timeout: int = 5, returning_clause=None, **kwargs) -> None: ...
def pragma(self, key, value=..., permanent: bool = False, schema: str | None = None): ...
cache_size: Incomplete
foreign_keys: Incomplete
Expand Down Expand Up @@ -968,7 +969,7 @@ class PostgresqlDatabase(Database):
psycopg3_adapter: Incomplete
def init(
self,
database,
database: str | None,
register_unicode: bool = True,
encoding=None,
isolation_level=None,
Expand Down Expand Up @@ -1012,7 +1013,7 @@ class MySQLDatabase(Database):
safe_create_index: bool
safe_drop_index: bool
sql_mode: str
def init(self, database, **kwargs) -> None: ...
def init(self, database: str | None, **kwargs) -> None: ...
def is_connection_usable(self) -> bool: ...
def default_values_insert(self, ctx): ...
def begin(self, isolation_level: str | None = None) -> None: ...
Expand Down Expand Up @@ -1515,7 +1516,7 @@ class _SortedFieldList:
class SchemaManager:
model: Incomplete
context_options: Incomplete
def __init__(self, model, database=None, **context_options) -> None: ...
def __init__(self, model, database: _DatabaseType | None = None, **context_options) -> None: ...

@property
def database(self): ...
Expand Down Expand Up @@ -1569,7 +1570,7 @@ class Metadata:
def __init__(
self,
model,
database=None,
database: _DatabaseType | None = None,
table_name=None,
indexes=None,
primary_key=None,
Expand Down Expand Up @@ -1616,7 +1617,7 @@ class Metadata:
def get_primary_keys(self): ...
def get_default_dict(self): ...
def fields_to_index(self) -> list[Incomplete]: ...
def set_database(self, database) -> None: ...
def set_database(self, database: _DatabaseType) -> None: ...
def set_table_name(self, table_name) -> None: ...

class SubclassAwareMetadata(Metadata):
Expand Down Expand Up @@ -1644,7 +1645,7 @@ class _BoundModelsContext(_callable_context_manager):
database: Incomplete
bind_refs: Incomplete
bind_backrefs: Incomplete
def __init__(self, models, database, bind_refs, bind_backrefs) -> None: ...
def __init__(self, models, database: _DatabaseType, bind_refs, bind_backrefs) -> None: ...
def __enter__(self): ...
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
Expand Down Expand Up @@ -1713,9 +1714,9 @@ class Model(metaclass=ModelBase):
def __ne__(self, other) -> Expression | bool: ... # type: ignore[override]
def __sql__(self, ctx): ...
@classmethod
def bind(cls, database, bind_refs: bool = True, bind_backrefs: bool = True, _exclude=None) -> bool: ...
def bind(cls, database: _DatabaseType, bind_refs: bool = True, bind_backrefs: bool = True, _exclude=None) -> bool: ...
@classmethod
def bind_ctx(cls, database, bind_refs: bool = True, bind_backrefs: bool = True) -> _BoundModelsContext: ...
def bind_ctx(cls, database: _DatabaseType, bind_refs: bool = True, bind_backrefs: bool = True) -> _BoundModelsContext: ...
@classmethod
def table_exists(cls): ...
@classmethod
Expand Down Expand Up @@ -1774,8 +1775,8 @@ class BaseModelSelect(_ModelQueryHelper):
__sub__ = except_
def __iter__(self): ...
def prefetch(self, *subqueries): ...
def get(self, database=None): ...
def get_or_none(self, database=None): ...
def get(self, database: _DatabaseType | None = None): ...
def get_or_none(self, database: _DatabaseType | None = None): ...
def group_by(self, *columns) -> Self: ...

class ModelCompoundSelectQuery(BaseModelSelect, CompoundSelectQuery): # type: ignore[misc]
Expand Down
Loading