Skip to content
72 changes: 72 additions & 0 deletions testing/src/scenario/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class _StateKwargs(TypedDict, total=False):
UnitID = int

CharmType = TypeVar('CharmType', bound=CharmBase)
_RelationType = TypeVar('_RelationType', bound='RelationBase')

logger = scenario_logger.getChild('state')

Expand Down Expand Up @@ -1746,6 +1747,38 @@ def get_relation(self, relation: int, /) -> RelationBase:
return state_relation
raise KeyError(f'relation: id={relation} not found in the State')

def get_regular_relation(self, relation: int, /) -> Relation:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really love the word "regular here", but I'm also not sure what to suggest instead.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't love it either. I went with it because the Juju docs on relations basically go Peer, Non-Peer, Subordinate, Non-Subordinate, and then:

A non-subordinate relation (aka ‘regular’) is a non-peer relation where the applications are both principal.

I'm guessing that's why Dima suggested it too. Open to alternatives though.

"""Get a regular relation from this State by relation id.

Raises:
TypeError: If this relation is not a ``Relation``.
"""
return self._get_typed_relation(relation, kind=Relation)

def get_peer_relation(self, relation: int, /) -> PeerRelation:
"""Get a peer relation from this State by relation id.

Raises:
TypeError: If this relation is not a ``PeerRelation``.
"""
return self._get_typed_relation(relation, kind=PeerRelation)

def get_subordinate_relation(self, relation: int, /) -> SubordinateRelation:
"""Get a subordinate relation from this State by relation id.

Raises:
TypeError: If this relation is not a ``SubordinateRelation``.
"""
return self._get_typed_relation(relation, kind=SubordinateRelation)

def _get_typed_relation(self, relation: int, kind: type[_RelationType]) -> _RelationType:
rel = self.get_relation(relation)
if not isinstance(rel, kind):
raise TypeError(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should be:

  • a TypeError (wrong type of relation about to be returned), or
  • a ValueError (wrong relation: int argument passed in)

f"Relation {relation} is not a {kind.__name__}, it's a {rel.__class__.__name__}"
)
return rel

def get_relations(self, endpoint: str) -> tuple[RelationBase, ...]:
"""Get all relations on this endpoint from the current state."""
# we rather normalize the endpoint than worry about cursed metadata situations such as:
Expand All @@ -1758,6 +1791,45 @@ def get_relations(self, endpoint: str) -> tuple[RelationBase, ...]:
r for r in self.relations if _normalise_name(r.endpoint) == normalized_endpoint
)

def get_regular_relations(self, endpoint: str) -> tuple[Relation, ...]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced about adding the plural/endpoint ones. I'm pretty sure when we've looked at use of this in the past, it's almost always code like:

state_in = State(..., relations={relation})
state_out = ctx.run(..., state=state_in)
assert state_out.get_relations('endpoint')[0]...

And that is cleaner as:

state_in = State(..., relations={relation})
state_out = ctx.run(..., state=state_in)
assert state_out.get_relation(relation.id)...

I had a quick look over a few random selections in charms, and they were all of that type. I'm sure there are exceptions, but this is already adding 3 new methods for getting by ID - I'm not convinced there is enough value to add another 3 as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I'll let @dimaqq weigh in too, but I'm happy to remove the plural ones.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose my one hesitation would be that getting relations by endpoint name rather than by ID feels more like the right level of expression for state-transition tests. But the current usage argument is pretty compelling.

"""Get regular relations on this endpoint from the current state.

Raises:
TypeError: If any of the relations on this endpoint is not
a ``Relation``.
"""
return self._get_typed_relations(endpoint, kind=Relation)

def get_peer_relations(self, endpoint: str) -> tuple[PeerRelation, ...]:
"""Get peer relations on this endpoint from the current state.

Raises:
TypeError: If any of the relations on this endpoint is not
a ``PeerRelation``.
"""
return self._get_typed_relations(endpoint, kind=PeerRelation)

def get_subordinate_relations(self, endpoint: str) -> tuple[SubordinateRelation, ...]:
"""Get subordinate relations on this endpoint from the current state.

Raises:
TypeError: If any of the relations on this endpoint is not
a ``SubordinateRelation``.
"""
return self._get_typed_relations(endpoint, kind=SubordinateRelation)

def _get_typed_relations(
self, endpoint: str, kind: type[_RelationType]
) -> tuple[_RelationType, ...]:
rels = self.get_relations(endpoint)
for rel in rels:
if not isinstance(rel, kind):
Comment on lines +1825 to +1826
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A zero-length return is valid, isn't it?
Then there's a possibility of a false positive, that is not raising an exception when we would have if there was a relation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A zero-length return is valid, isn't it? Then there's a possibility of a false positive, that is not raising an exception when we would have if there was a relation.

That's correct, though I don't think we can do any better here. The Context knows how the relation is defined in the metadata, but the State does not, it only knows what's in the current state.

raise TypeError(
f'Relation on endpoint {endpoint} is not a {kind.__name__}'
f", it's a {rel.__class__.__name__}"
)
return rels

@classmethod
def from_context(
cls,
Expand Down
111 changes: 111 additions & 0 deletions testing/tests/test_e2e/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,117 @@ def _on_config_changed(self, _):
assert isinstance(exc.value, KeyError) or isinstance(exc.value.__cause__, KeyError)


def test_get_regular_relation():
rel = Relation('foo')
state = State(relations={rel})
rel_out = state.get_regular_relation(rel.id)
assert rel_out == rel


def test_get_regular_relation_type_error():
rel = PeerRelation('foo')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_regular_relation(rel.id)


def test_get_peer_relation():
rel = PeerRelation('peers')
state = State(relations={rel})
rel_out = state.get_peer_relation(rel.id)
assert rel_out == rel


def test_get_peer_relation_type_error():
rel = Relation('foo')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_peer_relation(rel.id)


def test_get_subordinate_relation():
rel = SubordinateRelation('logging')
state = State(relations={rel})
rel_out = state.get_subordinate_relation(rel.id)
assert rel_out == rel


def test_get_subordinate_relation_type_error():
rel = PeerRelation('peers')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_subordinate_relation(rel.id)


def test_get_regular_relations():
rel1 = Relation('foo')
rel2 = Relation('foo')
state = State(relations={rel1, rel2})
rels_out = state.get_regular_relations('foo')
assert len(rels_out) == 2
assert rel1 in rels_out
assert rel2 in rels_out


def test_get_regular_relations_type_error():
rel = PeerRelation('peers')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_regular_relations('peers')


def test_get_peer_relations():
rel1 = PeerRelation('peers')
rel2 = PeerRelation('peers')
state = State(relations={rel1, rel2})
rels_out = state.get_peer_relations('peers')
assert len(rels_out) == 2
assert rel1 in rels_out
assert rel2 in rels_out


def test_get_peer_relations_type_error():
rel = SubordinateRelation('logging')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_peer_relations('logging')


def test_get_subordinate_relations():
rel1 = SubordinateRelation('logging')
rel2 = SubordinateRelation('logging')
state = State(relations={rel1, rel2})
rels_out = state.get_subordinate_relations('logging')
assert len(rels_out) == 2
assert rel1 in rels_out
assert rel2 in rels_out


def test_get_subordinate_relations_type_error():
rel = Relation('foo')
state = State(relations={rel})
with pytest.raises(TypeError):
state.get_subordinate_relations('foo')


def test_get_regular_relations_empty():
state = State()
rels_out = state.get_regular_relations('foo')
assert len(rels_out) == 0


def test_get_peer_relations_empty():
state = State()
rels_out = state.get_peer_relations('peers')
assert len(rels_out) == 0


def test_get_subordinate_relations_empty():
state = State()
rels_out = state.get_subordinate_relations('logging')
assert len(rels_out) == 0


@pytest.mark.parametrize('klass', (Relation, PeerRelation, SubordinateRelation))
def test_relation_positional_arguments(klass):
with pytest.raises(TypeError):
Expand Down