Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Change log

## Version 3.2.2
- Fixed:
- Better subclassing support: Determine classes dynamically,
so that methods like str() are aware when our types are subclassed.

## Version 3.2.1
- Fixed:
- Build system includes sortedcontainers dependency in the wheel again
Expand Down
35 changes: 22 additions & 13 deletions intervaltree/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Interval(namedtuple('IntervalBase', ['begin', 'end', 'data'])):

def __new__(cls, begin, end, data=None):
return super(Interval, cls).__new__(cls, begin, end, data)

def overlaps(self, begin, end=None):
"""
Whether the interval overlaps the given point, range or Interval.
Expand All @@ -44,7 +44,7 @@ def overlaps(self, begin, end=None):
if end is not None:
# An overlap means that some C exists that is inside both ranges:
# begin <= C < end
# and
# and
# self.begin <= C < self.end
# See https://stackoverflow.com/questions/3269434/whats-the-most-efficient-way-to-test-two-integer-ranges-for-overlap/3269471#3269471
return begin < self.end and end > self.begin
Expand Down Expand Up @@ -84,7 +84,7 @@ def contains_point(self, p):
:rtype: bool
"""
return self.begin <= p < self.end

def range_matches(self, other):
"""
Whether the begins equal and the ends equal. Compare __eq__().
Expand All @@ -93,10 +93,10 @@ def range_matches(self, other):
:rtype: bool
"""
return (
self.begin == other.begin and
self.begin == other.begin and
self.end == other.end
)

def contains_interval(self, other):
"""
Whether other is contained in this Interval.
Expand All @@ -108,10 +108,10 @@ def contains_interval(self, other):
self.begin <= other.begin and
self.end >= other.end
)

def distance_to(self, other):
"""
Returns the size of the gap between intervals, or 0
Returns the size of the gap between intervals, or 0
if they touch or overlap.
:param other: Interval or point
:return: distance
Expand Down Expand Up @@ -291,7 +291,7 @@ def _get_fields(self):
return self.begin, self.end, self.data
else:
return self.begin, self.end

def __repr__(self):
"""
Executable string representation of this Interval.
Expand All @@ -305,9 +305,18 @@ def __repr__(self):
s_begin = repr(self.begin)
s_end = repr(self.end)
if self.data is None:
return "Interval({0}, {1})".format(s_begin, s_end)
return "{0}({1}, {2})".format(
self.__class__.__name__,
s_begin,
s_end,
)
else:
return "Interval({0}, {1}, {2})".format(s_begin, s_end, repr(self.data))
return "{0}({1}, {2}, {3})".format(
self.__class__.__name__,
s_begin,
s_end,
repr(self.data),
)

__str__ = __repr__

Expand All @@ -317,12 +326,12 @@ def copy(self):
:return: copy of self
:rtype: Interval
"""
return Interval(self.begin, self.end, self.data)
return self.__class__(self.begin, self.end, self.data)

def __reduce__(self):
"""
For pickle-ing.
:return: pickle data
:rtype: tuple
"""
return Interval, self._get_fields()
return self.__class__, self._get_fields()
20 changes: 10 additions & 10 deletions intervaltree/intervaltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def from_tuples(cls, tups):
where the tuple lists begin, end, and optionally data.
"""
ivs = [Interval(*t) for t in tups]
return IntervalTree(ivs)
return cls(ivs)

def __init__(self, intervals=None):
"""
Expand Down Expand Up @@ -277,7 +277,7 @@ def copy(self):
Completes in O(n*log n) time.
:rtype: IntervalTree
"""
return IntervalTree(iv.copy() for iv in self)
return self.__class__(iv.copy() for iv in self)

def _add_boundaries(self, interval):
"""
Expand Down Expand Up @@ -407,7 +407,7 @@ def difference(self, other):
for iv in self:
if iv not in other:
ivs.add(iv)
return IntervalTree(ivs)
return self.__class__(ivs)

def difference_update(self, other):
"""
Expand All @@ -421,7 +421,7 @@ def union(self, other):
Returns a new tree, comprising all intervals from self
and other.
"""
return IntervalTree(set(self).union(other))
return self.__class__(set(self).union(other))

def intersection(self, other):
"""
Expand All @@ -433,7 +433,7 @@ def intersection(self, other):
for iv in shorter:
if iv in longer:
ivs.add(iv)
return IntervalTree(ivs)
return self.__class__(ivs)

def intersection_update(self, other):
"""
Expand All @@ -452,7 +452,7 @@ def symmetric_difference(self, other):
if not isinstance(other, set): other = set(other)
me = set(self)
ivs = me.difference(other).union(other.difference(me))
return IntervalTree(ivs)
return self.__class__(ivs)

def symmetric_difference_update(self, other):
"""
Expand Down Expand Up @@ -1193,7 +1193,7 @@ def __eq__(self, other):
:rtype: bool
"""
return (
isinstance(other, IntervalTree) and
isinstance(other, self.__class__) and
self.all_intervals == other.all_intervals
)

Expand All @@ -1203,9 +1203,9 @@ def __repr__(self):
"""
ivs = sorted(self)
if not ivs:
return "IntervalTree()"
return "{0}()".format(self.__class__.__name__)
else:
return "IntervalTree({0})".format(ivs)
return "{0}({1})".format(self.__class__.__name__, ivs)

__str__ = __repr__

Expand All @@ -1214,5 +1214,5 @@ def __reduce__(self):
For pickle-ing.
:rtype: tuple
"""
return IntervalTree, (sorted(self.all_intervals),)
return self.__class__, (sorted(self.all_intervals),)

17 changes: 9 additions & 8 deletions intervaltree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def from_interval(cls, interval):
:rtype : Node
"""
center = interval.begin
return Node(center, [interval])
return cls(center, [interval])

@classmethod
def from_intervals(cls, intervals):
Expand All @@ -71,7 +71,7 @@ def from_intervals(cls, intervals):
"""
if not intervals:
return None
return Node.from_sorted_intervals(sorted(intervals))
return cls.from_sorted_intervals(sorted(intervals))

@classmethod
def from_sorted_intervals(cls, intervals):
Expand All @@ -80,7 +80,7 @@ def from_sorted_intervals(cls, intervals):
"""
if not intervals:
return None
node = Node()
node = cls()
node = node.init_from_sorted(intervals)
return node

Expand All @@ -99,8 +99,8 @@ def init_from_sorted(self, intervals):
s_right.append(k)
else:
self.s_center.add(k)
self.left_node = Node.from_sorted_intervals(s_left)
self.right_node = Node.from_sorted_intervals(s_right)
self.left_node = self.__class__.from_sorted_intervals(s_left)
self.right_node = self.__class__.from_sorted_intervals(s_right)
return self.rotate()

def center_hit(self, interval):
Expand Down Expand Up @@ -212,7 +212,7 @@ def add(self, interval):
else:
direction = self.hit_branch(interval)
if not self[direction]:
self[direction] = Node.from_interval(interval)
self[direction] = self.__class__.from_interval(interval)
self.refresh_balance()
return self
else:
Expand Down Expand Up @@ -392,7 +392,7 @@ def get_new_s_center():
if iv.contains_point(new_x_center): yield iv

# Create a new node with the largest x_center possible.
child = Node(new_x_center, get_new_s_center())
child = self.__class__(new_x_center, get_new_s_center())
self.s_center -= child.s_center

#print('Pop hit! Returning child = {}'.format(
Expand Down Expand Up @@ -527,7 +527,8 @@ def __str__(self):
user, I'm not bothering to make this copy-paste-executable as a
constructor.
"""
return "Node<{0}, depth={1}, balance={2}>".format(
return "{0}<{1}, depth={2}, balance={3}>".format(
self.__class__.__name__,
self.x_center,
self.depth,
self.balance
Expand Down
29 changes: 29 additions & 0 deletions test/interval_methods/copy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from intervaltree import Interval
import pickle

def test_copy():
iv0 = Interval(1, 2, 3)
iv1 = iv0.copy()
assert iv1.begin == iv0.begin
assert iv1.end == iv0.end
assert iv1.data == iv0.data
assert iv1 == iv0

iv2 = pickle.loads(pickle.dumps(iv0))
assert iv2.begin == iv0.begin
assert iv2.end == iv0.end
assert iv2.data == iv0.data
assert iv2 == iv0


def test_copy_type():
class MyInterval(Interval):
pass
iv = MyInterval(1, 2)
c = iv.copy()
assert isinstance(c, MyInterval)


if __name__ == "__main__":
import pytest
pytest.main([__file__, '-v'])
52 changes: 52 additions & 0 deletions test/interval_methods/str_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from intervaltree import Interval

def test_str():
iv = Interval(0, 1)
s = str(iv)
assert s == 'Interval(0, 1)'
assert repr(iv) == s

iv = Interval(0, 1, '[0,1)')
s = str(iv)
assert s == "Interval(0, 1, '[0,1)')"
assert repr(iv) == s

iv = Interval((1,2), (3,4))
s = str(iv)
assert s == 'Interval((1, 2), (3, 4))'
assert repr(iv) == s

iv = Interval((1,2), (3,4), (5, 6))
s = str(iv)
assert s == 'Interval((1, 2), (3, 4), (5, 6))'
assert repr(iv) == s


def test_str_type():
class MyInterval(Interval):
pass

iv = MyInterval(0, 1)
s = str(iv)
assert s == 'MyInterval(0, 1)'
assert repr(iv) == s

iv = MyInterval(0, 1, '[0,1)')
s = str(iv)
assert s == "MyInterval(0, 1, '[0,1)')"
assert repr(iv) == s

iv = MyInterval((1,2), (3,4))
s = str(iv)
assert s == 'MyInterval((1, 2), (3, 4))'
assert repr(iv) == s

iv = MyInterval((1,2), (3,4), (5, 6))
s = str(iv)
assert s == 'MyInterval((1, 2), (3, 4), (5, 6))'
assert repr(iv) == s


if __name__ == "__main__":
import pytest
pytest.main([__file__, '-v'])
37 changes: 0 additions & 37 deletions test/interval_methods/unary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,6 @@ def test_isnull():
assert iv.is_null()


def test_copy():
iv0 = Interval(1, 2, 3)
iv1 = iv0.copy()
assert iv1.begin == iv0.begin
assert iv1.end == iv0.end
assert iv1.data == iv0.data
assert iv1 == iv0

iv2 = pickle.loads(pickle.dumps(iv0))
assert iv2.begin == iv0.begin
assert iv2.end == iv0.end
assert iv2.data == iv0.data
assert iv2 == iv0


def test_len():
iv = Interval(0, 0)
assert len(iv) == 3
Expand All @@ -72,28 +57,6 @@ def test_length():
assert iv.length() == 2.9


def test_str():
iv = Interval(0, 1)
s = str(iv)
assert s == 'Interval(0, 1)'
assert repr(iv) == s

iv = Interval(0, 1, '[0,1)')
s = str(iv)
assert s == "Interval(0, 1, '[0,1)')"
assert repr(iv) == s

iv = Interval((1,2), (3,4))
s = str(iv)
assert s == 'Interval((1, 2), (3, 4))'
assert repr(iv) == s

iv = Interval((1,2), (3,4), (5, 6))
s = str(iv)
assert s == 'Interval((1, 2), (3, 4), (5, 6))'
assert repr(iv) == s


def test_get_fields():
ivn = Interval(0, 1)
ivo = Interval(0, 1, 'hello')
Expand Down
Loading