Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,6 @@ jobs:
- uses: dtolnay/rust-toolchain@stable
with:
targets: ${{ matrix.platform.rust-target }}
components: rust-src
- uses: actions/setup-python@v6
with:
python-version: "3.14"
Expand All @@ -784,6 +783,18 @@ jobs:
needs: [fmt]
if: ${{ !contains(github.event.pull_request.labels.*.name, 'CI-build-full') && github.event_name == 'pull_request' }}
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- uses: dtolnay/rust-toolchain@stable
- uses: actions/setup-python@v6
with:
python-version: "3.14"
- run: python -m pip install --upgrade pip && pip install nox[uv]
- run: nox -s test-introspection

mypy-pytests:
needs: [fmt]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- uses: dtolnay/rust-toolchain@stable
Expand All @@ -793,7 +804,8 @@ jobs:
with:
python-version: "3.14"
- run: python -m pip install --upgrade pip && pip install nox[uv]
- run: nox -s test-introspection
- run: nox -s mypy
working-directory: pytests

conclusion:
needs:
Expand Down
29 changes: 29 additions & 0 deletions pytests/noxfile.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import shutil
from pathlib import Path

import nox
import sys
from nox.command import CommandFailed
Expand Down Expand Up @@ -34,3 +37,29 @@ def try_install_binary(package: str, constraint: str):
def bench(session: nox.Session):
session.install(".[dev]")
session.run("pytest", "--benchmark-enable", "--benchmark-only", *session.posargs)


@nox.session
def mypy(session: nox.Session):
session.env["MATURIN_PEP517_ARGS"] = "--profile=dev"
try:
# We move the stubs where maturin is expecting them to be
shutil.copytree("stubs", "pyo3_pytests")
Comment thread
Tpt marked this conversation as resolved.
(Path("pyo3_pytests") / "py.typed").touch()
session.install(".[dev]")

# TODO: remove --disable-error-code", "override" when __eq__ and __ne__ will always take object for input
# TODO: remove "--disable-error-code", "misc" when #[classattr] will be properly emitted
session.run_always(
"python",
"-m",
"mypy",
"tests",
"--disable-error-code",
"override",
"--disable-error-code",
"misc",
)
# TODO: enable stubtest when previously listed errors will be fixed session.run_always("python", "-m", "mypy.stubtest", "pyo3_pytests")
finally:
shutil.rmtree("pyo3_pytests")
1 change: 1 addition & 0 deletions pytests/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ classifiers = [
[project.optional-dependencies]
dev = [
"hypothesis>=3.55",
"mypy~=1.0",
"pytest-asyncio>=0.21,<2",
"pytest-benchmark>=3.4",
"pytest>=7",
Expand Down
17 changes: 11 additions & 6 deletions pytests/src/datetime.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#![cfg(not(Py_LIMITED_API))]

use pyo3::prelude::*;
use pyo3::types::{
PyDate, PyDateAccess, PyDateTime, PyDelta, PyDeltaAccess, PyTime, PyTimeAccess, PyTuple,
Expand Down Expand Up @@ -190,15 +188,22 @@ impl TzClass {
TzClass {}
}

fn utcoffset<'py>(&self, dt: &Bound<'py, PyDateTime>) -> PyResult<Bound<'py, PyDelta>> {
PyDelta::new(dt.py(), 0, 3600, 0, true)
#[pyo3(signature = (_dt, /))]
fn utcoffset<'py>(
&self,
_dt: Option<&Bound<'_, PyDateTime>>,
Comment thread
Tpt marked this conversation as resolved.
py: Python<'py>,
) -> PyResult<Bound<'py, PyDelta>> {
PyDelta::new(py, 0, 3600, 0, true)
}

fn tzname(&self, _dt: &Bound<'_, PyDateTime>) -> String {
#[pyo3(signature = (_dt, /))]
fn tzname(&self, _dt: Option<&Bound<'_, PyDateTime>>) -> String {
String::from("+01:00")
}

fn dst<'py>(&self, _dt: &Bound<'py, PyDateTime>) -> Option<Bound<'py, PyDelta>> {
#[pyo3(signature = (_dt, /))]
fn dst(&self, _dt: Option<&Bound<'_, PyDateTime>>) -> Option<Bound<'static, PyDelta>> {
None
}
}
Expand Down
1 change: 1 addition & 0 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod awaitable;
mod buf_and_str;
mod comparisons;
mod consts;
#[cfg(not(Py_LIMITED_API))]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Over time we have increased the amount of datetime module available with Py_LIMITED_API, this may not be necessary any more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The accessing traits (PyDateAccess...) are sadly not supported yet with Py_LIMITED_API

mod datetime;
mod dict_iter;
mod enums;
Expand Down
6 changes: 3 additions & 3 deletions pytests/stubs/datetime.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ from typing import final
@final
class TzClass(tzinfo):
def __new__(cls, /) -> TzClass: ...
def dst(self, /, _dt: datetime) -> timedelta | None: ...
def tzname(self, /, _dt: datetime) -> str: ...
def utcoffset(self, /, dt: datetime) -> timedelta: ...
def dst(self, _dt: datetime | None, /) -> timedelta | None: ...
def tzname(self, _dt: datetime | None, /) -> str: ...
def utcoffset(self, _dt: datetime | None, /) -> timedelta: ...

def date_from_timestamp(timestamp: int) -> date: ...
def datetime_from_timestamp(ts: float, tz: tzinfo | None = None) -> datetime: ...
Expand Down
58 changes: 38 additions & 20 deletions pytests/tests/test_comparisons.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Type, Union
from typing import Type, TypeVar

import sys
import pytest
Expand All @@ -24,22 +24,25 @@ def __eq__(self, other: object) -> bool:
else:
return NotImplemented

def __ne__(self, other: Self) -> bool:
def __ne__(self, other: object) -> bool:
if isinstance(other, self.__class__):
return self.x != other.x
else:
return NotImplemented


EqType = TypeVar("EqType", Eq, EqDerived, PyEq)


@pytest.mark.skipif(
sys.implementation.name == "graalpy"
and __graalpython__.get_graalvm_version().startswith("24.1"), # noqa: F821
and __graalpython__.get_graalvm_version().startswith("24.1"), # type: ignore[name-defined] # noqa: F821
reason="Bug in GraalPy 24.1",
)
@pytest.mark.parametrize(
"ty", (Eq, EqDerived, PyEq), ids=("rust", "rust-derived", "python")
)
def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
def test_eq(ty: Type[EqType]):
a = ty(0)
b = ty(0)
c = ty(1)
Expand All @@ -62,28 +65,31 @@ def test_eq(ty: Type[Union[Eq, EqDerived, PyEq]]):
assert c != 1

with pytest.raises(TypeError):
assert a <= b
assert a <= b # type: ignore[operator]

with pytest.raises(TypeError):
assert a >= b
assert a >= b # type: ignore[operator]

with pytest.raises(TypeError):
assert a < c
assert a < c # type: ignore[operator]

with pytest.raises(TypeError):
assert c > a
assert c > a # type: ignore[operator]


class PyEqDefaultNe:
def __init__(self, x: int) -> None:
self.x = x

def __eq__(self, other: Self) -> bool:
return self.x == other.x
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and self.x == other.x


EqDefaultType = TypeVar("EqDefaultType", EqDefaultNe, PyEqDefaultNe)


@pytest.mark.parametrize("ty", (EqDefaultNe, PyEqDefaultNe), ids=("rust", "python"))
def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
def test_eq_default_ne(ty: Type[EqDefaultType]):
a = ty(0)
b = ty(0)
c = ty(1)
Expand All @@ -99,16 +105,16 @@ def test_eq_default_ne(ty: Type[Union[EqDefaultNe, PyEqDefaultNe]]):
assert not (b == c)

with pytest.raises(TypeError):
assert a <= b
assert a <= b # type: ignore[operator]

with pytest.raises(TypeError):
assert a >= b
assert a >= b # type: ignore[operator]

with pytest.raises(TypeError):
assert a < c
assert a < c # type: ignore[operator]

with pytest.raises(TypeError):
assert c > a
assert c > a # type: ignore[operator]


class PyOrdered:
Expand All @@ -121,10 +127,14 @@ def __lt__(self, other: Self) -> bool:
def __le__(self, other: Self) -> bool:
return self.x <= other.x

def __eq__(self, other: Self) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.x == other.x

def __ne__(self, other: Self) -> bool:
def __ne__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.x != other.x

def __gt__(self, other: Self) -> bool:
Expand All @@ -134,12 +144,15 @@ def __ge__(self, other: Self) -> bool:
return self.x >= other.x


OrderedType = TypeVar("OrderedType", Ordered, OrderedDerived, OrderedRichCmp, PyOrdered)


@pytest.mark.parametrize(
"ty",
(Ordered, OrderedDerived, OrderedRichCmp, PyOrdered),
ids=("rust", "rust-derived", "rust-richcmp", "python"),
)
def test_ordered(ty: Type[Union[Ordered, OrderedDerived, OrderedRichCmp, PyOrdered]]):
def test_ordered(ty: Type[OrderedType]):
a = ty(0)
b = ty(0)
c = ty(1)
Expand Down Expand Up @@ -174,7 +187,9 @@ def __lt__(self, other: Self) -> bool:
def __le__(self, other: Self) -> bool:
return self.x <= other.x

def __eq__(self, other: Self) -> bool:
def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__):
return NotImplemented
return self.x == other.x

def __gt__(self, other: Self) -> bool:
Expand All @@ -184,10 +199,13 @@ def __ge__(self, other: Self) -> bool:
return self.x >= other.x


OrderedDefaultType = TypeVar("OrderedDefaultType", OrderedDefaultNe, PyOrderedDefaultNe)


@pytest.mark.parametrize(
"ty", (OrderedDefaultNe, PyOrderedDefaultNe), ids=("rust", "python")
)
def test_ordered_default_ne(ty: Type[Union[OrderedDefaultNe, PyOrderedDefaultNe]]):
def test_ordered_default_ne(ty: Type[OrderedDefaultType]):
a = ty(0)
b = ty(0)
c = ty(1)
Expand Down
48 changes: 16 additions & 32 deletions pytests/tests/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,20 @@ def test_complex_enum_field_getters():
)
def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
if isinstance(variant, enums.ComplexEnum.Int):
x = variant.i
assert x == 42
assert variant.i == 42
elif isinstance(variant, enums.ComplexEnum.Float):
x = variant.f
assert x == 3.14
assert variant.f == 3.14
elif isinstance(variant, enums.ComplexEnum.Str):
x = variant.s
assert x == "hello"
assert variant.s == "hello"
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
assert True
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
x = variant.a
y = variant.b
z = variant.c
assert x == 42
assert y == 3.14
assert z is True
assert variant.a == 42
assert variant.b == 3.14
assert variant.c is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 42
assert y is None
assert variant.a == 42
assert variant.b is None
else:
assert False

Expand All @@ -113,28 +105,20 @@ def test_complex_enum_desugared_match(variant: enums.ComplexEnum):
def test_complex_enum_pyfunction_in_out_desugared_match(variant: enums.ComplexEnum):
variant = enums.do_complex_stuff(variant)
if isinstance(variant, enums.ComplexEnum.Int):
x = variant.i
assert x == 5
assert variant.i == 5
elif isinstance(variant, enums.ComplexEnum.Float):
x = variant.f
assert x == 9.8596
assert variant.f == 9.8596
elif isinstance(variant, enums.ComplexEnum.Str):
x = variant.s
assert x == "42"
assert variant.s == "42"
elif isinstance(variant, enums.ComplexEnum.EmptyStruct):
assert True
elif isinstance(variant, enums.ComplexEnum.MultiFieldStruct):
x = variant.a
y = variant.b
z = variant.c
assert x == 42
assert y == 3.14
assert z is True
assert variant.a == 42
assert variant.b == 3.14
assert variant.c is True
elif isinstance(variant, enums.ComplexEnum.VariantWithDefault):
x = variant.a
y = variant.b
assert x == 84
assert y == "HELLO"
assert variant.a == 84
assert variant.b == "HELLO"
else:
assert False

Expand Down
Loading