Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions matrix/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Awaitable, TypeVar

from nio import ErrorResponse, Response

from matrix.errors import MatrixError

T = TypeVar("T", bound=Response)


async def matrix_call(coro: Awaitable[T], /, *, error_message: str) -> T:
"""Await `coro`, translating any failure into a `MatrixError`.

matrix-nio's `AsyncClient` methods don't raise on API-level errors; they
return an `ErrorResponse` instead of raising. This wraps a single call so
both transport-level exceptions and nio `ErrorResponse` results become a
`MatrixError` carrying `error_message`.

## Example

```python
response = await matrix_call(
self.client.room_kick(room_id=self.room_id, user_id=user_id),
error_message="Failed to kick user",
)
```
"""
try:
response = await coro
except Exception as e:
raise MatrixError(f"{error_message}: {e}") from e

if isinstance(response, ErrorResponse):
raise MatrixError(f"{error_message}: {response}")

return response
6 changes: 5 additions & 1 deletion matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
CheckError,
RoomNotFoundError,
)
from .api import matrix_call


class Bot(Registry):
Expand Down Expand Up @@ -328,7 +329,10 @@ async def run(self) -> None:
if self.config.token:
self.client.access_token = self.config.token
else:
login_resp = await self.client.login(self.config.password)
login_resp = await matrix_call(
self.client.login(self.config.password),
error_message="Failed to log in",
)
self.log.info("logged in: %s", login_resp)

sync_task = asyncio.create_task(self.client.sync_forever(timeout=30_000))
Expand Down
33 changes: 17 additions & 16 deletions matrix/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from matrix.types import Reaction
from matrix.content import ReactionContent, EditContent
from matrix.errors import MatrixError
from matrix.api import matrix_call

if TYPE_CHECKING:
from .room import Room # pragma: no cover
Expand Down Expand Up @@ -116,14 +117,14 @@ async def thumbsup(ctx: Context):
"""
content = ReactionContent(event_id=self.event_id, emoji=emoji)

try:
await self.client.room_send(
await matrix_call(
self.client.room_send(
room_id=self.room.room_id,
message_type="m.reaction",
content=content.build(),
)
except Exception as e:
raise MatrixError(f"Failed to add reaction: {e}")
),
error_message="Failed to add reaction",
)

async def edit(self, new_body: str) -> None:
"""Updates the message content to the new text.
Expand All @@ -139,15 +140,15 @@ async def typo(ctx: Context):
"""
content = EditContent(new_body, original_event_id=self.event_id)

try:
await self.client.room_send(
await matrix_call(
self.client.room_send(
room_id=self.room.room_id,
message_type="m.room.message",
content=content.build(),
)
self._body = new_body
except Exception as e:
raise MatrixError(f"Failed to edit message: {e}")
),
error_message="Failed to edit message",
)
self._body = new_body

async def delete(self, reason: str | None = None) -> None:
"""Removes the message content from the room. This action cannot be undone.
Expand All @@ -166,11 +167,11 @@ async def oops(ctx: Context):
await message.delete(reason="Violated room rules")
```
"""
try:
await self.client.room_redact(
await matrix_call(
self.client.room_redact(
room_id=self.room.room_id,
event_id=self.event_id,
reason=reason,
)
except Exception as e:
raise MatrixError(f"Failed to delete message: {e}")
),
error_message="Failed to delete message",
)
93 changes: 43 additions & 50 deletions matrix/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from nio import AsyncClient, MatrixRoom, Event

from matrix.errors import MatrixError
from matrix.api import matrix_call
from matrix.message import Message
from matrix.content import (
BaseMessageContent,
Expand Down Expand Up @@ -329,21 +329,21 @@ async def send_file(self, file: File) -> Message:

async def _send_payload(self, payload: BaseMessageContent) -> Message:
"""Send a BaseMessageContent payload and return a Message object."""
try:
resp = await self.client.room_send(
resp = await matrix_call(
self.client.room_send(
room_id=self.room_id,
message_type="m.room.message",
content=payload.build(),
)
event = await self.fetch_event(resp.event_id)
),
error_message="Failed to send message",
)
event = await self.fetch_event(resp.event_id)

return Message(
room=self,
event=event,
client=self.client,
)
except Exception as e:
raise MatrixError(f"Failed to send message: {e}")
return Message(
room=self,
event=event,
client=self.client,
)

async def fetch_event(self, event_id: str) -> Event:
"""Fetch a Matrix event by its ID.
Expand All @@ -354,14 +354,11 @@ async def fetch_event(self, event_id: str) -> Event:
print(event.sender)
```
"""
try:
response = await self.client.room_get_event(
room_id=self.room_id,
event_id=event_id,
)
return response.event
except Exception as e:
raise MatrixError(f"Failed to get event: {e}")
response = await matrix_call(
self.client.room_get_event(room_id=self.room_id, event_id=event_id),
error_message="Failed to get event",
)
return response.event

async def fetch_message(self, event_id: str) -> Message:
"""Fetch a Message by its event ID.
Expand Down Expand Up @@ -393,14 +390,14 @@ async def on_message(room: Room, event: Event):
await room.mark_as_read(event.event_id)
```
"""
try:
await self.client.room_read_markers(
await matrix_call(
self.client.room_read_markers(
room_id=self.room_id,
fully_read_event=event_id,
read_event=event_id,
)
except Exception as e:
raise MatrixError(f"Failed to mark as read: {e}")
),
error_message="Failed to mark as read",
)

async def invite_user(self, user_id: str) -> None:
"""Invite a user to the room.
Expand All @@ -415,10 +412,10 @@ async def invite_user(self, user_id: str) -> None:
await room.invite_user("@alice:example.com")
```
"""
try:
await self.client.room_invite(room_id=self.room_id, user_id=user_id)
except Exception as e:
raise MatrixError(f"Failed to invite user: {e}")
await matrix_call(
self.client.room_invite(room_id=self.room_id, user_id=user_id),
error_message="Failed to invite user",
)

async def ban_user(self, user_id: str, reason: str | None = None) -> None:
"""Ban a user from the room.
Expand All @@ -436,12 +433,10 @@ async def ban_user(self, user_id: str, reason: str | None = None) -> None:
await room.ban_user("@spammer:example.com", reason="Spam and harassment")
```
"""
try:
await self.client.room_ban(
room_id=self.room_id, user_id=user_id, reason=reason
)
except Exception as e:
raise MatrixError(f"Failed to ban user: {e}")
await matrix_call(
self.client.room_ban(room_id=self.room_id, user_id=user_id, reason=reason),
error_message="Failed to ban user",
)

async def unban_user(self, user_id: str) -> None:
"""Unban a user from the room.
Expand All @@ -456,10 +451,10 @@ async def unban_user(self, user_id: str) -> None:
await room.unban_user("@alice:example.com")
```
"""
try:
await self.client.room_unban(room_id=self.room_id, user_id=user_id)
except Exception as e:
raise MatrixError(f"Failed to unban user: {e}")
await matrix_call(
self.client.room_unban(room_id=self.room_id, user_id=user_id),
error_message="Failed to unban user",
)

async def kick_user(self, user_id: str, reason: str | None = None) -> None:
"""Kick a user from the room.
Expand All @@ -478,12 +473,10 @@ async def kick_user(self, user_id: str, reason: str | None = None) -> None:
await room.kick_user("@troublemaker:example.com", reason="Violating room rules")
```
"""
try:
await self.client.room_kick(
room_id=self.room_id, user_id=user_id, reason=reason
)
except Exception as e:
raise MatrixError(f"Failed to kick user: {e}")
await matrix_call(
self.client.room_kick(room_id=self.room_id, user_id=user_id, reason=reason),
error_message="Failed to kick user",
)

async def get_members(self) -> list[str]:
"""Fetch the list of user IDs currently joined to the room.
Expand All @@ -498,8 +491,8 @@ async def get_members(self) -> list[str]:
print(f"{len(members)} members: {', '.join(members)}")
```
"""
try:
response = await self.client.joined_members(self.room_id)
return [member.user_id for member in response.members]
except Exception as e:
raise MatrixError(f"Failed to get members: {e}")
response = await matrix_call(
self.client.joined_members(self.room_id),
error_message="Failed to get members",
)
return [member.user_id for member in response.members]
43 changes: 43 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from nio import RoomSendResponse, RoomSendError

from matrix.errors import MatrixError
from matrix.api import matrix_call


@pytest.mark.asyncio
async def test_matrix_call_with_success__expect_response_returned():
async def call():
return RoomSendResponse(event_id="$event123", room_id="!room:example.com")

response = await matrix_call(call(), error_message="Failed to send message")

assert response.event_id == "$event123"


@pytest.mark.asyncio
async def test_matrix_call_with_transport_exception__expect_matrix_error():
async def call():
raise Exception("Network error")

with pytest.raises(MatrixError, match="Failed to send message: Network error"):
await matrix_call(call(), error_message="Failed to send message")


@pytest.mark.asyncio
async def test_matrix_call_with_error_response__expect_matrix_error():
async def call():
return RoomSendError("not allowed", "M_FORBIDDEN")

with pytest.raises(MatrixError, match="Failed to send message: .*M_FORBIDDEN"):
await matrix_call(call(), error_message="Failed to send message")


def test_matrix_call_requires_keyword_error_message__expect_type_error():
with pytest.raises(TypeError):
matrix_call(None, "Failed to send message")


def test_matrix_call_requires_positional_coro__expect_type_error():
with pytest.raises(TypeError):
matrix_call(coro=None, error_message="Failed to send message")
18 changes: 17 additions & 1 deletion tests/test_bot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pytest

from unittest.mock import AsyncMock, MagicMock, patch
from nio import MatrixRoom, RoomMessageText
from nio import MatrixRoom, RoomMessageText, LoginError

from matrix import Bot, Config, Extension, Room, Space
from matrix.errors import (
CheckError,
CommandNotFoundError,
AlreadyRegisteredError,
MatrixError,
)


Expand Down Expand Up @@ -555,6 +556,21 @@ async def mock_login(password):
bot._on_ready.assert_awaited_once()


@pytest.mark.asyncio
async def test_run_with_login_api_error__expect_matrix_error(bot):
bot._client.login = AsyncMock(
return_value=LoginError("bad credentials", "M_FORBIDDEN")
)
bot._client.sync_forever = AsyncMock()
bot._on_ready = AsyncMock()

with pytest.raises(MatrixError, match="Failed to log in"):
await bot.run()

bot._client.sync_forever.assert_not_called()
bot._on_ready.assert_not_called()


def test_start_handles_keyboard_interrupt(caplog):
bot = Bot()
bot._client = MagicMock()
Expand Down
Loading
Loading