Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions matrix/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async def on_error(self, error: Exception) -> None:
self.log.exception("Unhandled error: '%s'", error)

async def _on_error(self, error: Exception) -> None:
if handler := self._error_handlers.get(type(error)):
if handler := self.resolve_handler(self._error_handlers, error):
await handler(error)
return

Expand All @@ -259,19 +259,23 @@ async def on_command_error(self, _ctx: Context, error: Exception) -> None:
"""Override this in a subclass."""
self.log.exception("Unhandled error: '%s'", error)

async def _on_command_error(self, ctx: Context, error: Exception) -> None:
async def _on_command_error(self, ctx: Context, error: Exception) -> bool:
"""
Handles errors raised during command invocation.

This method is called automatically when a command error occurs.
If a specific error handler is registered for the type of the
exception, it will be invoked with the current context and error.

Returns True if a specific handler was found and invoked, False if
it fell through to the default dispatch/log path.
"""
if handler := self._command_error_handlers.get(type(error)):
if handler := self.resolve_handler(self._command_error_handlers, error):
await handler(ctx, error)
return
return True

await self._dispatch("on_command_error", ctx, error)
return False

# ENTRYPOINT

Expand Down
11 changes: 9 additions & 2 deletions matrix/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,19 @@ async def on_error(self, ctx: "Context", error: Exception) -> None:
"""
Executes the registered error handler if present.
"""
handler = None

if handler := self._error_handlers.get(type(error)):
for cls in inspect.getmro(type(error)):
if cls in self._error_handlers:
handler = self._error_handlers[cls]
break

if handler:
await handler(ctx, error)
return

await ctx.bot.on_command_error(ctx, error)
if await ctx.bot._on_command_error(ctx, error):
return

if self._on_error:
await self._on_error(ctx, error)
Expand Down
10 changes: 10 additions & 0 deletions matrix/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,13 @@ def _register_command_error(
if not exception:
exception = Exception
self._command_error_handlers[exception] = func

@staticmethod
def resolve_handler(
handlers: Dict[type[Exception], F], error: Exception
) -> Optional[F]:
"""Look up the handler registered for the error's type or nearest base class."""
for cls in inspect.getmro(type(error)):
if cls in handlers:
return handlers[cls]
return None
94 changes: 93 additions & 1 deletion tests/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from unittest.mock import AsyncMock, MagicMock, patch
from nio import MatrixRoom, RoomMessageText, LoginError

from matrix import Bot, Config, Extension, Room, Space
from matrix import Bot, Config, Context, Extension, Room, Space
from matrix.errors import (
CheckError,
CommandNotFoundError,
Expand Down Expand Up @@ -290,6 +290,46 @@ async def custom_error_handler(e):
assert called, "Fallback error handler was not called"


@pytest.mark.asyncio
async def test_on_error__with_subclass_error__expect_fallback_handler_called(bot):
called_with = None

@bot.error()
async def custom_error_handler(e):
nonlocal called_with
called_with = e

error = ValueError("subclass of Exception")
await bot._on_error(error)

assert called_with is error, "Fallback handler should catch Exception subclasses"


@pytest.mark.asyncio
async def test_on_error__with_specific_and_fallback_handlers__expect_specific_handler_called(
bot,
):
fallback_called = False
specific_called = False

@bot.error()
async def fallback_handler(e):
nonlocal fallback_called
fallback_called = True

@bot.error(ValueError)
async def specific_handler(e):
nonlocal specific_called
specific_called = True

await bot._on_error(ValueError("test error"))

assert specific_called, "Specific handler should be used when available"
assert (
not fallback_called
), "Fallback handler should not run when a specific one matches"


@pytest.mark.asyncio
async def test_on_error_logs_when_no_handler(bot):
error = Exception("test")
Expand All @@ -298,6 +338,24 @@ async def test_on_error_logs_when_no_handler(bot):
bot.log.exception.assert_called_once_with("Unhandled error: '%s'", error)


@pytest.mark.asyncio
async def test_on_command_error__with_matching_handler__expect_returns_true(bot):
@bot.error(ValueError, context=True)
async def handler(ctx, error):
pass

result = await bot._on_command_error(MagicMock(), ValueError("boom"))

assert result is True


@pytest.mark.asyncio
async def test_on_command_error__with_no_matching_handler__expect_returns_false(bot):
result = await bot._on_command_error(MagicMock(), ValueError("boom"))

assert result is False


@pytest.mark.asyncio
async def test_process_commands_executes_command(bot, event):
called = False
Expand Down Expand Up @@ -379,6 +437,40 @@ async def global_check(ctx):
assert isinstance(bot._on_command_error.call_args[0][1], CheckError)


@pytest.mark.asyncio
async def test_command_error_handler__with_error_raised_in_command_body__expect_handler_called(
bot,
):
handled = None

@bot.error(ValueError, context=True)
async def on_value_error(ctx, error):
nonlocal handled
handled = error

@bot.command()
async def boom(ctx):
raise ValueError("kaboom")

event = RoomMessageText.from_dict(
{
"content": {"body": "!boom", "msgtype": "m.text"},
"event_id": "$ev3",
"origin_server_ts": 1234567890,
"sender": "@user:matrix.org",
"type": "m.room.message",
}
)

room = MatrixRoom("!roomid", "alias")

with patch.object(Context, "send_help", new_callable=AsyncMock) as mock_send_help:
await bot._process_commands(room, event)

assert isinstance(handled, ValueError)
mock_send_help.assert_not_called()


@pytest.mark.asyncio
async def test_command_executes(bot):
called = False
Expand Down
62 changes: 61 additions & 1 deletion tests/test_command.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import inspect

from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from matrix.errors import MissingArgumentError
from matrix.command import Command

Expand All @@ -10,6 +10,10 @@ class DummyBot:
async def on_command_error(self, _ctx, _error):
return None

async def _on_command_error(self, ctx, error):
await self.on_command_error(ctx, error)
return False


class DummyContext:
def __init__(self, args=None):
Expand Down Expand Up @@ -145,6 +149,62 @@ async def handler(_ctx, _error):
assert called


@pytest.mark.asyncio
async def test_error_handler__with_exception_subclass__expect_handler_called():
class BaseCustomError(Exception):
pass

class SubCustomError(BaseCustomError):
pass

async def failing_command(ctx):
raise SubCustomError("boom")

cmd = Command(failing_command)
ctx = DummyContext(args=[])
handled = None

@cmd.error(BaseCustomError)
async def handler(_ctx, error):
nonlocal handled
handled = error

await cmd(ctx)
assert isinstance(handled, SubCustomError)


@pytest.mark.asyncio
async def test_on_error__with_bot_handler_matched__expect_no_fallback_help():
async def failing_command(ctx):
raise ValueError("boom")

cmd = Command(failing_command)
ctx = DummyContext(args=[])
ctx.bot._on_command_error = AsyncMock(return_value=True)
ctx.send_help = AsyncMock()

await cmd(ctx)

ctx.send_help.assert_not_called()
ctx.logger.exception.assert_not_called()


@pytest.mark.asyncio
async def test_on_error__with_no_bot_handler_matched__expect_fallback_help():
async def failing_command(ctx):
raise ValueError("boom")

cmd = Command(failing_command)
ctx = DummyContext(args=[])
ctx.bot._on_command_error = AsyncMock(return_value=False)
ctx.send_help = AsyncMock()

await cmd(ctx)

ctx.send_help.assert_awaited_once()
ctx.logger.exception.assert_called_once()


@pytest.mark.asyncio
async def test_before_and_after_invoke():
call_order = []
Expand Down
Loading