Skip to content

Commit 7991a09

Browse files
committed
_ExclusionBound, full === detection, local version fixes
1 parent c74f5f0 commit 7991a09

2 files changed

Lines changed: 171 additions & 93 deletions

File tree

src/packaging/specifiers.py

Lines changed: 127 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -49,63 +49,78 @@ def _trim_release(release: tuple[int, ...]) -> tuple[int, ...]:
4949
return release if end == len(release) else release[:end]
5050

5151

52+
# Sentinel kinds for _ExclusionBound.
53+
_AFTER_LOCALS: Final[int] = 0 # sorts after V+local, before V.post0
54+
_AFTER_POSTS: Final[int] = 1 # sorts after V.postN, before next release
55+
56+
5257
@functools.total_ordering
53-
class _PostExcludeBound:
54-
"""A bound version that sorts after all post-releases of a base version.
58+
class _ExclusionBound:
59+
"""A synthetic bound that sorts between version families.
60+
61+
PEP 440 has exclusion rules that cannot be expressed with plain Version
62+
bounds. This sentinel encodes those rules into the version ordering so
63+
that interval arithmetic handles them correctly.
5564
56-
Used for ``>V`` (where V is not a post-release) to model PEP 440's rule
57-
that ``>V`` must not match post-releases of V.
65+
``_AFTER_LOCALS``: sorts after V and all V+local, before V.post0.
66+
Used for ``<=V``, ``==V``, ``!=V``, and ``>V.postN`` to correctly
67+
handle local versions.
5868
59-
Ordering::
69+
``_AFTER_POSTS``: sorts after all V.postN (and V+local), before the
70+
next release segment. Used for ``>V`` (non-post) to exclude
71+
post-releases per PEP 440.
6072
61-
1.0 < 1.0.post0 < 1.0.post1 < _PostExcludeBound(1.0) < 1.0.1.dev0
73+
Ordering for base version V::
6274
63-
This type is only used in interval bounds. Most bounds use plain Version
64-
objects for direct C-level comparison performance.
75+
V < V+local < AFTER_LOCALS(V) < V.post0 < ... < AFTER_POSTS(V) < V.0.1
6576
"""
6677

67-
__slots__ = ("_trimmed_release", "version")
78+
__slots__ = ("_kind", "_trimmed_release", "version")
6879

69-
def __init__(self, version: Version) -> None:
80+
def __init__(self, version: Version, kind: int) -> None:
7081
self.version = version
82+
self._kind = kind
7183
self._trimmed_release = _trim_release(version.release)
7284

73-
def _is_post_family(self, other: Version) -> bool:
74-
"""Is ``other`` the same version as self.version, or a post-release of it?
75-
76-
A "post-release of V" shares V's epoch, release, and pre segment
77-
but adds a post (and possibly dev-of-post) segment.
78-
"""
85+
def _is_family(self, other: Version) -> bool:
86+
"""Is ``other`` a version that this sentinel sorts above?"""
7987
v = self.version
80-
return (
88+
if not (
8189
other.epoch == v.epoch
8290
and _trim_release(other.release) == self._trimmed_release
8391
and other.pre == v.pre
84-
and (other.dev == v.dev or other.post is not None)
85-
)
92+
):
93+
return False
94+
if self._kind == _AFTER_LOCALS:
95+
# Local family: exact same public version (any local label).
96+
return other.post == v.post and other.dev == v.dev
97+
# Post family: same base + any post-release (or identical).
98+
return other.dev == v.dev or other.post is not None
8699

87100
def __eq__(self, other: object) -> bool:
88-
if isinstance(other, _PostExcludeBound):
89-
return self.version == other.version
101+
if isinstance(other, _ExclusionBound):
102+
return self.version == other.version and self._kind == other._kind
90103
return NotImplemented
91104

92105
def __lt__(self, other: object) -> bool:
93-
if isinstance(other, _PostExcludeBound):
94-
return self.version < other.version
106+
if isinstance(other, _ExclusionBound):
107+
if self.version != other.version:
108+
return self.version < other.version
109+
return self._kind < other._kind
95110
assert isinstance(other, Version)
96-
# self < other iff other is NOT in the post-family and other > V
97-
return not self._is_post_family(other) and self.version < other
111+
# self < other iff other is NOT in the family and other > V
112+
return not self._is_family(other) and self.version < other
98113

99114
def __hash__(self) -> int:
100-
return hash(self.version)
115+
return hash((self.version, self._kind))
101116

102117

103118
if typing.TYPE_CHECKING:
104119
from typing_extensions import TypeAlias
105120

106-
# Bound version: plain Version for most operators, _PostExcludeBound
107-
# for >V (sorts after V.postN), or None for unbounded.
108-
_BoundVersion: TypeAlias = Union[Version, _PostExcludeBound, None]
121+
# Bound version: plain Version for most operators, _ExclusionBound
122+
# for local/post exclusion zones, or None for unbounded.
123+
_BoundVersion: TypeAlias = Union[Version, _ExclusionBound, None]
109124

110125
# A specifier bound: (bound_version, inclusive).
111126
_SpecifierBound: TypeAlias = tuple[_BoundVersion, bool]
@@ -504,10 +519,8 @@ def _require_spec_version(self, version: str) -> Version:
504519
def _to_intervals(self) -> list[_SpecifierInterval]:
505520
"""Convert this specifier to sorted, non-overlapping intervals.
506521
507-
Each standard operator maps to one or two intervals. The ``===``
508-
operator is modeled as full range since it uses arbitrary string
509-
matching; the actual check is done separately.
510-
Result is cached.
522+
Each standard operator maps to one or two intervals. ``===`` is
523+
modeled as full range (actual check done separately). Cached.
511524
"""
512525
if self._intervals is not None:
513526
return self._intervals
@@ -519,51 +532,62 @@ def _to_intervals(self) -> list[_SpecifierInterval]:
519532
self._intervals = _FULL_RANGE
520533
return _FULL_RANGE
521534

522-
result: list[_SpecifierInterval]
523-
524535
if ver_str.endswith(".*"):
525-
# Wildcard bounds: ==1.2.* matches [1.2.dev0, 1.3.dev0).
526-
base = self._require_spec_version(ver_str[:-2])
527-
lower = _base_dev0(base)
528-
upper = _next_prefix_dev0(base)
529-
if op == "==":
530-
result = [((lower, True), (upper, False))]
531-
else: # !=
532-
result = [
533-
((None, False), (lower, False)),
534-
((upper, True), (None, False)),
535-
]
536+
result = self._wildcard_intervals(op, ver_str)
536537
else:
537-
v = self._require_spec_version(ver_str)
538-
if op == ">=":
539-
result = [((v, True), (None, False))]
540-
elif op == "<=":
541-
result = [((None, False), (v, True))]
542-
elif op == ">":
543-
if v.is_postrelease:
544-
result = [((v, False), (None, False))]
545-
else:
546-
result = [((_PostExcludeBound(v), False), (None, False))]
547-
elif op == "<":
548-
bound = v if v.is_prerelease else _base_dev0(v)
549-
if bound <= _MIN_VERSION:
550-
result = []
551-
else:
552-
result = [((None, False), (bound, False))]
553-
elif op == "==":
554-
result = [((v, True), (v, True))]
555-
elif op == "!=":
556-
result = [((None, False), (v, False)), ((v, False), (None, False))]
557-
elif op == "~=":
558-
prefix = v.__replace__(release=v.release[:-1])
559-
upper = _next_prefix_dev0(prefix)
560-
result = [((v, True), (upper, False))]
561-
else:
562-
raise ValueError(f"Unknown operator: {op!r}") # pragma: no cover
538+
result = self._standard_intervals(op, ver_str)
563539

564540
self._intervals = result
565541
return result
566542

543+
def _wildcard_intervals(self, op: str, ver_str: str) -> list[_SpecifierInterval]:
544+
# ==1.2.* -> [1.2.dev0, 1.3.dev0); !=1.2.* -> complement.
545+
base = self._require_spec_version(ver_str[:-2])
546+
lower = _base_dev0(base)
547+
upper = _next_prefix_dev0(base)
548+
if op == "==":
549+
return [((lower, True), (upper, False))]
550+
# !=
551+
return [((None, False), (lower, False)), ((upper, True), (None, False))]
552+
553+
def _standard_intervals(self, op: str, ver_str: str) -> list[_SpecifierInterval]:
554+
v = self._require_spec_version(ver_str)
555+
has_local = "+" in ver_str
556+
after_locals = _ExclusionBound(v, _AFTER_LOCALS)
557+
558+
if op == ">=":
559+
return [((v, True), (None, False))]
560+
561+
if op == "<=":
562+
return [((None, False), (after_locals, True))]
563+
564+
if op == ">":
565+
# >V.postN: skip V.postN+local. >V: skip all V.postN too.
566+
kind = _AFTER_LOCALS if v.is_postrelease else _AFTER_POSTS
567+
return [((_ExclusionBound(v, kind), False), (None, False))]
568+
569+
if op == "<":
570+
bound = v if v.is_prerelease else _base_dev0(v)
571+
if bound <= _MIN_VERSION:
572+
return []
573+
return [((None, False), (bound, False))]
574+
575+
if op == "==":
576+
# ==V (no local) matches V+local; ==V+local matches exactly.
577+
eq_upper: Version | _ExclusionBound = v if has_local else after_locals
578+
return [((v, True), (eq_upper, True))]
579+
580+
if op == "!=":
581+
# !=V (no local) excludes V+local; !=V+local excludes exactly.
582+
ne_upper: Version | _ExclusionBound = v if has_local else after_locals
583+
return [((None, False), (v, False)), ((ne_upper, False), (None, False))]
584+
585+
if op == "~=":
586+
prefix = v.__replace__(release=v.release[:-1])
587+
return [((v, True), (_next_prefix_dev0(prefix), False))]
588+
589+
raise ValueError(f"Unknown operator: {op!r}") # pragma: no cover
590+
567591
@property
568592
def prereleases(self) -> bool | None:
569593
# If there is an explicit prereleases set for this, then we'll just
@@ -1366,8 +1390,6 @@ def is_unsatisfiable(self) -> bool:
13661390
"""Check whether this specifier set can never be satisfied.
13671391
13681392
Returns True if no version can satisfy all specifiers simultaneously.
1369-
Returns False if the set might be satisfiable (conservative: may
1370-
return False for some unsatisfiable sets involving === specifiers).
13711393
13721394
>>> SpecifierSet(">=2.0,<1.0").is_unsatisfiable()
13731395
True
@@ -1388,17 +1410,41 @@ def is_unsatisfiable(self) -> bool:
13881410

13891411
result = not self._get_intervals()
13901412

1391-
# Extra: === with an unparsable version can only match raw
1392-
# strings, but standard specs reject raw strings.
1393-
if not result and any(
1394-
s.operator == "===" and _coerce_version(s.version) is None
1395-
for s in self._specs
1396-
):
1397-
result = any(s.operator != "===" for s in self._specs)
1413+
if not result:
1414+
result = self._check_arbitrary_unsatisfiable()
13981415

13991416
self._is_unsatisfiable = result
14001417
return result
14011418

1419+
def _check_arbitrary_unsatisfiable(self) -> bool:
1420+
"""Check === (arbitrary equality) specs for unsatisfiability.
1421+
1422+
=== uses case-insensitive string comparison, so the only candidate
1423+
that can match ``===V`` is the literal string V. This method
1424+
checks whether that candidate is excluded by other specifiers.
1425+
"""
1426+
arbitrary = [s for s in self._specs if s.operator == "==="]
1427+
if not arbitrary:
1428+
return False
1429+
1430+
# Multiple === must agree on the same string (case-insensitive).
1431+
first = arbitrary[0].version.lower()
1432+
if any(s.version.lower() != first for s in arbitrary[1:]):
1433+
return True
1434+
1435+
# The sole candidate is the === version string. Check whether
1436+
# it can satisfy every standard spec.
1437+
standard = [s for s in self._specs if s.operator != "==="]
1438+
if not standard:
1439+
return False
1440+
1441+
candidate = _coerce_version(arbitrary[0].version)
1442+
if candidate is None:
1443+
# Unparsable string cannot satisfy any standard spec.
1444+
return True
1445+
1446+
return not all(s.contains(candidate) for s in standard)
1447+
14021448
def __contains__(self, item: UnparsedVersion) -> bool:
14031449
"""Return whether or not the item is contained in this specifier.
14041450

0 commit comments

Comments
 (0)