From 01862479268b1f9c0522e08f257828da3bac6b3d Mon Sep 17 00:00:00 2001 From: mic1on Date: Thu, 12 Mar 2026 12:02:41 +0800 Subject: [PATCH 1/2] fix: make default notify instance and publisher state thread-safe - Replace global default instance with ContextVar to isolate per execution context - Add RLock to protect Publisher.channels and Publisher.retry_config - Snapshot state before publishing to avoid concurrent modifications - Add tests for thread/async task isolation and concurrent state mutations --- src/use_notify/decorator/core.py | 94 +++++++++++++--------- src/use_notify/notification.py | 74 +++++++++++------ tests/test_decorator.py | 69 ++++++++++++++++ tests/test_notification.py | 134 +++++++++++++++++++++++++++++++ 4 files changed, 308 insertions(+), 63 deletions(-) diff --git a/src/use_notify/decorator/core.py b/src/use_notify/decorator/core.py index 4ee3984..fa28848 100644 --- a/src/use_notify/decorator/core.py +++ b/src/use_notify/decorator/core.py @@ -4,6 +4,7 @@ """ import asyncio +from contextvars import ContextVar import functools import inspect import logging @@ -19,13 +20,16 @@ logger = logging.getLogger(__name__) -# 全局默认通知实例 -_default_notify_instance: Optional[Notify] = None +# 默认通知实例,按当前执行上下文隔离,避免线程/任务间相互污染 +_default_notify_instance_var: ContextVar[Optional[Notify]] = ContextVar( + "use_notify_default_instance", + default=None, +) RetriableExceptionsInput = Optional[Sequence[Type[BaseException]]] def set_default_notify_instance(notify_instance: Notify) -> None: - """设置全局默认通知实例 + """设置当前执行上下文的默认通知实例 Args: notify_instance: 要设置为默认的 Notify 实例 @@ -41,27 +45,25 @@ def set_default_notify_instance(notify_instance: Notify) -> None: def my_task(): return "任务完成" """ - global _default_notify_instance if not isinstance(notify_instance, Notify): raise NotifyConfigError("notify_instance 必须是 Notify 类的实例") - _default_notify_instance = notify_instance - logger.info("已设置全局默认通知实例") + _default_notify_instance_var.set(notify_instance) + logger.info("已设置默认通知实例") def get_default_notify_instance() -> Optional[Notify]: - """获取全局默认通知实例 + """获取当前执行上下文的默认通知实例 Returns: 当前的默认通知实例,如果未设置则返回 None """ - return _default_notify_instance + return _default_notify_instance_var.get() def clear_default_notify_instance() -> None: - """清除全局默认通知实例""" - global _default_notify_instance - _default_notify_instance = None - logger.info("已清除全局默认通知实例") + """清除当前执行上下文的默认通知实例""" + _default_notify_instance_var.set(None) + logger.info("已清除默认通知实例") class NotifyDecorator: @@ -90,27 +92,16 @@ def __init__( max_retries, retry_delay, retry_backoff, retriable_exceptions, ) - # 如果没有提供 notify_instance,尝试使用全局默认实例 - if notify_instance is None: - notify_instance = get_default_notify_instance() - if notify_instance is None: - notify_instance = Notify() - logger.warning("未提供 notify_instance 且未设置全局默认实例,创建了一个空的 Notify 实例。请确保添加通知渠道或设置默认实例。") - else: - logger.debug("使用全局默认通知实例") - - notify_instance = self._apply_retry_overrides( - notify_instance=notify_instance, - max_retries=max_retries, - retry_delay=retry_delay, - retry_backoff=retry_backoff, - retriable_exceptions=retriable_exceptions, - ) - self.notify_instance = notify_instance self.title = title self.notify_on_success = notify_on_success self.notify_on_error = notify_on_error + self.timeout = timeout + self.max_retries = max_retries + self.retry_delay = retry_delay + self.retry_backoff = retry_backoff + self.retriable_exceptions = retriable_exceptions + self._warned_missing_default_notify = False # 创建消息格式化器 self.formatter = MessageFormatter( @@ -120,12 +111,6 @@ def __init__( include_result=include_result ) - # 创建通知发送器 - self.sender = NotificationSender( - notify_instance=self.notify_instance, - timeout=timeout - ) - def __call__(self, func: Callable) -> Callable: """装饰器调用""" if inspect.iscoroutinefunction(func): @@ -222,7 +207,8 @@ def _send_success_notification(self, context: ExecutionContext) -> None: try: message = self.formatter.format_success_message(context) title = self.title or message["title"] - self.sender.send_notification(title, message["content"]) + sender = self._build_sender() + sender.send_notification(title, message["content"]) except Exception as e: logger.warning(f"发送成功通知失败: {e}") @@ -231,7 +217,8 @@ async def _send_success_notification_async(self, context: ExecutionContext) -> N try: message = self.formatter.format_success_message(context) title = self.title or message["title"] - await self.sender.send_notification_async(title, message["content"]) + sender = self._build_sender() + await sender.send_notification_async(title, message["content"]) except Exception as e: logger.warning(f"发送成功通知失败: {e}") @@ -240,7 +227,8 @@ def _send_error_notification(self, context: ExecutionContext) -> None: try: message = self.formatter.format_error_message(context) title = self.title or message["title"] - self.sender.send_notification(title, message["content"]) + sender = self._build_sender() + sender.send_notification(title, message["content"]) except Exception as e: logger.warning(f"发送错误通知失败: {e}") @@ -249,9 +237,37 @@ async def _send_error_notification_async(self, context: ExecutionContext) -> Non try: message = self.formatter.format_error_message(context) title = self.title or message["title"] - await self.sender.send_notification_async(title, message["content"]) + sender = self._build_sender() + await sender.send_notification_async(title, message["content"]) except Exception as e: logger.warning(f"发送错误通知失败: {e}") + + def _build_sender(self) -> NotificationSender: + notify_instance = self._resolve_notify_instance() + return NotificationSender(notify_instance=notify_instance, timeout=self.timeout) + + def _resolve_notify_instance(self) -> Notify: + notify_instance = self.notify_instance + + if notify_instance is None: + notify_instance = get_default_notify_instance() + if notify_instance is None: + notify_instance = Notify() + if not self._warned_missing_default_notify: + logger.warning( + "未提供 notify_instance 且当前执行上下文未设置默认实例,创建了一个空的 Notify 实例。请确保添加通知渠道或设置默认实例。" + ) + self._warned_missing_default_notify = True + else: + logger.debug("使用全局默认通知实例") + + return self._apply_retry_overrides( + notify_instance=notify_instance, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff=self.retry_backoff, + retriable_exceptions=self.retriable_exceptions, + ) def _validate_config(self, *args) -> None: """验证配置参数""" diff --git a/src/use_notify/notification.py b/src/use_notify/notification.py index f457c51..724b977 100644 --- a/src/use_notify/notification.py +++ b/src/use_notify/notification.py @@ -4,6 +4,7 @@ import smtplib import time from dataclasses import dataclass +from threading import RLock from typing import List, Optional, Tuple, Type, TypeVar import httpx @@ -72,7 +73,8 @@ def __init__( ): if channels is None: channels = [] - self.channels = channels + self._state_lock = RLock() + self.channels = tuple(channels) self.retry_config = RetryConfig( max_retries=max_retries, retry_delay=retry_delay, @@ -87,8 +89,10 @@ def add(self, *channels): Args: *channels: Variable number of BaseChannel objects. """ - for channel in channels: - self.channels.append(channel) + if not channels: + return + with self._state_lock: + self.channels = self.channels + tuple(channels) def configure_retry( self: PublisherT, @@ -100,22 +104,25 @@ def configure_retry( """ Update retry policy for subsequent sends. """ - self.retry_config = RetryConfig( + retry_config = RetryConfig( max_retries=max_retries, retry_delay=retry_delay, retry_backoff=retry_backoff, retriable_exceptions=retriable_exceptions, ) + with self._state_lock: + self.retry_config = retry_config return self def publish(self, *args, **kwargs): """ Publish a notification to all channels. """ + channels, retry_config = self._snapshot_state() failures = [] - for channel in self.channels: + for channel in channels: try: - self._send_with_retry(channel, *args, **kwargs) + self._send_with_retry(channel, retry_config, *args, **kwargs) except Exception as error: failures.append((self._channel_name(channel), error)) @@ -126,20 +133,28 @@ async def publish_async(self, *args, **kwargs): """ Publish a notification asynchronously to all channels. """ - tasks = [self._send_with_retry_async(channel, *args, **kwargs) for channel in self.channels] + channels, retry_config = self._snapshot_state() + tasks = [ + self._send_with_retry_async(channel, retry_config, *args, **kwargs) + for channel in channels + ] results = await asyncio.gather(*tasks, return_exceptions=True) failures = [] - for channel, result in zip(self.channels, results): + for channel, result in zip(channels, results): if isinstance(result, Exception): failures.append((self._channel_name(channel), result)) if failures: self._raise_publish_error(failures) - def _send_with_retry(self, channel, *args, **kwargs): - max_attempts = self.retry_config.max_retries + 1 - delay = self.retry_config.retry_delay + def _snapshot_state(self): + with self._state_lock: + return self.channels, self.retry_config + + def _send_with_retry(self, channel, retry_config: RetryConfig, *args, **kwargs): + max_attempts = retry_config.max_retries + 1 + delay = retry_config.retry_delay for attempt in range(1, max_attempts + 1): try: @@ -149,7 +164,7 @@ def _send_with_retry(self, channel, *args, **kwargs): if attempt == max_attempts: raise - if not self._is_retriable_exception(error): + if not self._is_retriable_exception(error, retry_config): logger.debug( "Channel %s send failed with non-retriable %s: %s", self._channel_name(channel), @@ -158,14 +173,16 @@ def _send_with_retry(self, channel, *args, **kwargs): ) raise - self._log_retry(channel, attempt, error, delay) + self._log_retry(channel, attempt, error, delay, retry_config) if delay > 0: time.sleep(delay) - delay *= self.retry_config.retry_backoff + delay *= retry_config.retry_backoff - async def _send_with_retry_async(self, channel, *args, **kwargs): - max_attempts = self.retry_config.max_retries + 1 - delay = self.retry_config.retry_delay + async def _send_with_retry_async( + self, channel, retry_config: RetryConfig, *args, **kwargs + ): + max_attempts = retry_config.max_retries + 1 + delay = retry_config.retry_delay for attempt in range(1, max_attempts + 1): try: @@ -175,7 +192,7 @@ async def _send_with_retry_async(self, channel, *args, **kwargs): if attempt == max_attempts: raise - if not self._is_retriable_exception(error): + if not self._is_retriable_exception(error, retry_config): logger.debug( "Channel %s send failed with non-retriable %s: %s", self._channel_name(channel), @@ -184,21 +201,28 @@ async def _send_with_retry_async(self, channel, *args, **kwargs): ) raise - self._log_retry(channel, attempt, error, delay) + self._log_retry(channel, attempt, error, delay, retry_config) if delay > 0: await asyncio.sleep(delay) - delay *= self.retry_config.retry_backoff + delay *= retry_config.retry_backoff @staticmethod def _channel_name(channel) -> str: return channel.__class__.__name__ - def _log_retry(self, channel, attempt: int, error: Exception, delay: float): + def _log_retry( + self, + channel, + attempt: int, + error: Exception, + delay: float, + retry_config: RetryConfig, + ): logger.debug( "Channel %s send failed on attempt %s/%s with %s: %s. Retrying in %.2fs", self._channel_name(channel), attempt, - self.retry_config.max_retries + 1, + retry_config.max_retries + 1, error.__class__.__name__, error, delay, @@ -210,7 +234,9 @@ def _raise_publish_error(failures): raise failures[0][1] raise NotificationPublishError(failures) - def _is_retriable_exception(self, error: Exception) -> bool: + def _is_retriable_exception( + self, error: Exception, retry_config: RetryConfig + ) -> bool: if isinstance(error, httpx.HTTPStatusError): if error.response is None: return False @@ -226,7 +252,7 @@ def _is_retriable_exception(self, error: Exception) -> bool: if isinstance(error, smtplib.SMTPResponseException): return 400 <= error.smtp_code < 500 - return isinstance(error, self.retry_config.retriable_exceptions) + return isinstance(error, retry_config.retriable_exceptions) class Notify(Publisher): diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 2b9d723..dfe311c 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,4 +1,6 @@ import asyncio +from concurrent.futures import ThreadPoolExecutor +import threading import pytest @@ -74,6 +76,18 @@ def task(): assert task() == "ok" assert len(channel.sync_messages) == 1 + def test_default_instance_is_resolved_at_call_time(self): + channel = RecordingChannel() + + @notify() + def task(): + return "ok" + + set_default_notify_instance(useNotify([channel])) + + assert task() == "ok" + assert len(channel.sync_messages) == 1 + def test_explicit_instance_overrides_default(self): default_channel = RecordingChannel() explicit_channel = RecordingChannel() @@ -117,3 +131,58 @@ def task(): assert task() == "business-result" assert len(channel.sync_messages) == 1 + + def test_default_instance_is_isolated_per_thread(self): + first_channel = RecordingChannel() + second_channel = RecordingChannel() + barrier = threading.Barrier(2) + + @notify() + def task(): + return "ok" + + def worker(channel): + set_default_notify_instance(useNotify([channel])) + barrier.wait() + try: + return task() + finally: + clear_default_notify_instance() + + with ThreadPoolExecutor(max_workers=2) as executor: + first_result = executor.submit(worker, first_channel) + second_result = executor.submit(worker, second_channel) + + assert first_result.result() == "ok" + assert second_result.result() == "ok" + assert len(first_channel.sync_messages) == 1 + assert len(second_channel.sync_messages) == 1 + + @pytest.mark.asyncio + async def test_default_instance_is_isolated_per_async_task(self): + first_channel = RecordingChannel() + second_channel = RecordingChannel() + + @notify(include_result=True) + async def task(label): + await asyncio.sleep(0) + return label + + async def worker(label, channel): + set_default_notify_instance(useNotify([channel])) + try: + return await task(label) + finally: + clear_default_notify_instance() + + first_result, second_result = await asyncio.gather( + worker("first", first_channel), + worker("second", second_channel), + ) + + assert first_result == "first" + assert second_result == "second" + assert len(first_channel.async_messages) == 1 + assert len(second_channel.async_messages) == 1 + assert "first" in first_channel.async_messages[0]["content"] + assert "second" in second_channel.async_messages[0]["content"] diff --git a/tests/test_notification.py b/tests/test_notification.py index eb92fdc..2f2555e 100644 --- a/tests/test_notification.py +++ b/tests/test_notification.py @@ -1,12 +1,40 @@ +import asyncio +import threading + import httpx import pytest +import use_notify.notification as notification_module from use_notify import NotificationPublishError, useNotify, useNotifyChannel from use_notify.notification import Publisher, RetryConfig from tests.helpers import RecordingChannel, make_http_status_error +class BlockingSyncChannel(RecordingChannel): + def __init__(self, started_event: threading.Event, release_event: threading.Event): + super().__init__() + self.started_event = started_event + self.release_event = release_event + + def send(self, content, title=None): + self.started_event.set() + assert self.release_event.wait(timeout=1) + super().send(content, title) + + +class BlockingAsyncChannel(RecordingChannel): + def __init__(self, started_event: asyncio.Event, release_event: asyncio.Event): + super().__init__() + self.started_event = started_event + self.release_event = release_event + + async def send_async(self, content, title=None): + self.started_event.set() + await asyncio.wait_for(self.release_event.wait(), timeout=1) + await super().send_async(content, title) + + def test_publisher_add_and_publish_across_channels(): first = RecordingChannel() second = RecordingChannel() @@ -82,6 +110,105 @@ def test_publisher_aggregates_failures_after_other_channels_continue(): assert len(error_info.value.failures) == 2 +def test_publisher_copies_initial_channel_collection(): + initial_channels = [RecordingChannel()] + publisher = Publisher(initial_channels) + + initial_channels.append(RecordingChannel()) + publisher.publish("hello") + + assert len(publisher.channels) == 1 + + +def test_publisher_add_does_not_affect_in_flight_sync_publish(): + started = threading.Event() + release = threading.Event() + first = BlockingSyncChannel(started, release) + added = RecordingChannel() + publisher = Publisher([first]) + errors = [] + + publish_thread = threading.Thread( + target=lambda: _publish_and_capture_error(publisher, errors, "hello") + ) + publish_thread.start() + + assert started.wait(timeout=1) + publisher.add(added) + release.set() + publish_thread.join(timeout=1) + + assert not publish_thread.is_alive() + assert not errors + assert len(first.sync_messages) == 1 + assert added.sync_messages == [] + + publisher.publish("later") + + assert len(added.sync_messages) == 1 + + +@pytest.mark.asyncio +async def test_publisher_add_does_not_affect_in_flight_async_publish(): + started = asyncio.Event() + release = asyncio.Event() + first = BlockingAsyncChannel(started, release) + added = RecordingChannel() + publisher = Publisher([first]) + + publish_task = asyncio.create_task(publisher.publish_async("hello")) + + await asyncio.wait_for(started.wait(), timeout=1) + publisher.add(added) + release.set() + await asyncio.wait_for(publish_task, timeout=1) + + assert len(first.async_messages) == 1 + assert added.async_messages == [] + + await publisher.publish_async("later") + + assert len(added.async_messages) == 1 + + +def test_configure_retry_does_not_affect_in_flight_publish(monkeypatch): + channel = RecordingChannel( + sync_failures=[RuntimeError("temporary"), RuntimeError("temporary")] + ) + publisher = Publisher( + [channel], + max_retries=2, + retry_delay=0.01, + retriable_exceptions=(RuntimeError,), + ) + sleep_started = threading.Event() + release_sleep = threading.Event() + errors = [] + + def controlled_sleep(_delay): + sleep_started.set() + assert release_sleep.wait(timeout=1) + + monkeypatch.setattr(notification_module.time, "sleep", controlled_sleep) + + publish_thread = threading.Thread( + target=lambda: _publish_and_capture_error(publisher, errors, "hello") + ) + publish_thread.start() + + assert sleep_started.wait(timeout=1) + publisher.configure_retry( + max_retries=0, + retriable_exceptions=(TimeoutError,), + ) + release_sleep.set() + publish_thread.join(timeout=1) + + assert not publish_thread.is_alive() + assert not errors + assert len(channel.sync_messages) == 3 + + def test_configure_retry_returns_self_and_updates_policy(): publisher = Publisher() @@ -122,3 +249,10 @@ def test_notify_from_settings_builds_case_insensitive_channels(): def test_notify_from_settings_rejects_unknown_channel(): with pytest.raises(ValueError, match="Unknown channel"): useNotify.from_settings({"unknown": {"token": "x"}}) + + +def _publish_and_capture_error(publisher, errors, content): + try: + publisher.publish(content) + except Exception as error: # pragma: no cover - exercised via assertions + errors.append(error) From dd4768bc0ee2e13899976d1fd650bb80dd0fa5ac Mon Sep 17 00:00:00 2001 From: mic1on Date: Thu, 12 Mar 2026 17:52:19 +0800 Subject: [PATCH 2/2] fix: make sync notification timeout work in async event loops - Use ThreadPoolExecutor to execute sync notifications with timeout - Shutdown executor with wait=False to avoid blocking main thread - Add tests for sync timeout in both sync and async contexts - Add test for async timeout behavior - Remove unused NotifySendError import - Fix code style issues (blank lines, imports) --- src/use_notify/decorator/sender.py | 60 +++++++++---------- tests/test_decorator.py | 94 ++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 34 deletions(-) diff --git a/src/use_notify/decorator/sender.py b/src/use_notify/decorator/sender.py index aff24a9..43c2148 100644 --- a/src/use_notify/decorator/sender.py +++ b/src/use_notify/decorator/sender.py @@ -5,10 +5,10 @@ import asyncio import logging +from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError from typing import Optional from ..notification import Notify -from .exceptions import NotifySendError logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ class NotificationSender: """通知发送器""" - + def __init__( self, notify_instance: Notify, @@ -24,42 +24,34 @@ def __init__( ): self.notify_instance = notify_instance self.timeout = timeout - + def send_notification(self, title: str, content: str) -> None: """发送同步通知""" try: if self.timeout: - # 对于同步调用,我们使用 asyncio.wait_for 来实现超时 - # 但这需要在异步上下文中运行,所以我们创建一个新的事件循环 + # 使用线程池执行同步调用,实现超时控制 + # 无论是否在异步事件循环中,都能正确应用超时 + executor = ThreadPoolExecutor(max_workers=1) try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # 如果已经在事件循环中,直接调用 - self.notify_instance.publish(title=title, content=content) - else: - # 如果不在事件循环中,使用 run_until_complete - loop.run_until_complete( - asyncio.wait_for( - self._send_async_internal(title, content), - timeout=self.timeout - ) - ) - except RuntimeError: - # 如果没有事件循环,创建一个新的 - asyncio.run( - asyncio.wait_for( - self._send_async_internal(title, content), - timeout=self.timeout - ) + future = executor.submit( + self.notify_instance.publish, title=title, content=content ) + future.result(timeout=self.timeout) + finally: + # 使用 shutdown(wait=False) 立即关闭线程池,不等待线程完成 + executor.shutdown(wait=False) else: self.notify_instance.publish(title=title, content=content) - + logger.info(f"通知发送成功: {title}") - + + except FutureTimeoutError: + error_msg = f"通知发送超时({self.timeout}秒)" + logger.warning(error_msg) + self._handle_send_error(asyncio.TimeoutError(error_msg)) except Exception as error: self._handle_send_error(error) - + async def send_notification_async(self, title: str, content: str) -> None: """发送异步通知""" try: @@ -70,24 +62,24 @@ async def send_notification_async(self, title: str, content: str) -> None: ) else: await self._send_async_internal(title, content) - + logger.info(f"异步通知发送成功: {title}") - + except Exception as error: self._handle_send_error(error) - + async def _send_async_internal(self, title: str, content: str) -> None: """内部异步发送方法""" await self.notify_instance.publish_async(title=title, content=content) - + def _handle_send_error(self, error: Exception) -> None: """处理发送错误""" error_msg = f"通知发送失败: {str(error)}" logger.warning(error_msg) - + # 记录详细错误信息但不抛出异常,避免影响原函数执行 logger.debug(f"通知发送错误详情: {error}", exc_info=True) - + # 可以选择是否抛出异常,根据需求决定 # 默认情况下不抛出,只记录日志 - # raise NotifySendError(error_msg) from error \ No newline at end of file + # raise NotifySendError(error_msg) from error diff --git a/tests/test_decorator.py b/tests/test_decorator.py index dfe311c..8f29cb7 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -186,3 +186,97 @@ async def worker(label, channel): assert len(second_channel.async_messages) == 1 assert "first" in first_channel.async_messages[0]["content"] assert "second" in second_channel.async_messages[0]["content"] + + def test_sync_timeout_does_not_block_main_thread(self): + """测试同步超时不阻塞主线程""" + import time + + class SlowChannel(RecordingChannel): + def send(self, content, title=None): + time.sleep(2) + super().send(content, title) + + channel = SlowChannel() + notify_instance = useNotify([channel]) + + @notify(notify_instance=notify_instance, timeout=0.1) + def task(): + return "ok" + + start = time.time() + result = task() + elapsed = time.time() - start + + # 函数应该立即返回,不等待2秒 + assert result == "ok" + assert elapsed < 0.5, f"函数执行时间 {elapsed:.2f}s 超过预期" + # 超时错误应该被记录(不抛出,避免影响原函数执行) + + @pytest.mark.asyncio + async def test_sync_timeout_in_async_event_loop(self): + """测试同步超时在异步事件循环中也能正确应用""" + import time + + class SlowChannel(RecordingChannel): + def send(self, content, title=None): + time.sleep(2) + super().send(content, title) + + channel = SlowChannel() + notify_instance = useNotify([channel]) + + @notify(notify_instance=notify_instance, timeout=0.1) + def sync_task(): + return "ok" + + import asyncio + start = asyncio.get_event_loop().time() + + # 在异步上下文中调用同步装饰器 + result = sync_task() + + elapsed = asyncio.get_event_loop().time() - start + + # 函数应该立即返回,不等待2秒 + assert result == "ok" + assert elapsed < 0.5, f"函数执行时间 {elapsed:.2f}s 超过预期" + + def test_sync_send_without_timeout(self): + """测试不设置超时时正常发送""" + channel = RecordingChannel() + notify_instance = useNotify([channel]) + + @notify(notify_instance=notify_instance) + def task(): + return "ok" + + result = task() + assert result == "ok" + assert len(channel.sync_messages) == 1 + + @pytest.mark.asyncio + async def test_async_timeout_works(self): + """测试异步超时正常工作""" + class SlowChannel(RecordingChannel): + async def send_async(self, content, title=None): + await asyncio.sleep(2) + await super().send_async(content, title) + + channel = SlowChannel() + notify_instance = useNotify([channel]) + + @notify(notify_instance=notify_instance, timeout=0.1, notify_on_error=False) + async def task(): + return "ok" + + start = asyncio.get_event_loop().time() + result = await task() + elapsed = asyncio.get_event_loop().time() - start + + # 函数应该立即返回,不等待2秒 + assert result == "ok" + assert elapsed < 0.5, f"函数执行时间 {elapsed:.2f}s 超过预期" + # 超时应该生效,通知发送失败 + assert len(channel.async_messages) == 0 + +