diff --git a/openhands/automation/scheduler.py b/openhands/automation/scheduler.py index ac1d5c0..457f37b 100644 --- a/openhands/automation/scheduler.py +++ b/openhands/automation/scheduler.py @@ -12,13 +12,14 @@ import logging from datetime import datetime, timedelta +from pydantic import ValidationError from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from openhands.automation.db import using_sqlite from openhands.automation.models import Automation, AutomationRun -from openhands.automation.utils import is_automation_due, utcnow -from openhands.automation.utils.run import create_pending_run +from openhands.automation.schemas import TriggerAdapter +from openhands.automation.utils import utcnow logger = logging.getLogger("automation.scheduler") @@ -74,6 +75,60 @@ async def _fetch_enabled_automations( return list(result.scalars().all()) +async def _create_pending_runs( + session: AsyncSession, + automations: list[Automation], + now: datetime, +) -> list[AutomationRun]: + """Ask each automation's trigger to create a PENDING run if it's due. + + Each automation's ``trigger`` JSON is parsed into a typed model + (``CronTrigger``/``EventTrigger``/``GithubTrigger``) and its + :meth:`_TriggerBase.create_pending_run` is awaited. A return value of + ``None`` means "not due"; any other return is appended to the result. + + Calls run **sequentially** because they share a single + :class:`AsyncSession` (SQLAlchemy async sessions are not safe for + concurrent use). Triggers that need to fan out external I/O — e.g. + :class:`~openhands.automation.schemas.GithubTrigger` polling multiple + repos — do so internally. + + Per-trigger exceptions are logged and skipped so one bad trigger cannot + starve the rest of the batch. 
+ """ + created: list[AutomationRun] = [] + for automation in automations: + try: + trigger = TriggerAdapter.validate_python(automation.trigger) + except ValidationError: + logger.exception("Invalid trigger config for automation %s", automation.id) + continue + + try: + run = await trigger.create_pending_run(session, automation, now) + except Exception: + logger.exception( + "Trigger %s raised while processing automation %s", + type(trigger).__name__, + automation.id, + ) + continue + + if run is None: + continue + + created.append(run) + logger.info( + "Created pending run: run_id=%s automation_id=%s name=%s trigger_type=%s", + run.id, + automation.id, + automation.name, + automation.trigger.get("type"), + ) + + return created + + async def poll_and_schedule( session_factory: async_sessionmaker[AsyncSession], batch_size: int = DEFAULT_BATCH_SIZE, @@ -117,25 +172,7 @@ async def poll_and_schedule( for automation in automations: automation.last_polled_at = now - due_automations = [a for a in automations if is_automation_due(a, now)] - - for automation in due_automations: - try: - run = await create_pending_run(session, automation) - created_runs.append(run) - logger.info( - "Created pending run: run_id=%s automation_id=%s " - "name=%s schedule=%s", - run.id, - automation.id, - automation.name, - automation.trigger.get("schedule"), - ) - except Exception: - logger.exception( - "Failed to create run for automation %s", - automation.id, - ) + created_runs.extend(await _create_pending_runs(session, automations, now)) # Always commit to release row locks from FOR UPDATE SKIP LOCKED, # even if no runs were created diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index cf7d6db..48b0a89 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -1,17 +1,49 @@ """Pydantic request/response schemas for the API.""" +from __future__ import annotations + +import logging import re import uuid from datetime import datetime 
from enum import StrEnum -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal +from zoneinfo import ZoneInfo +import httpx +import jmespath from croniter import croniter -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator +from jmespath.exceptions import JMESPathError +from pydantic import ( + BaseModel, + ConfigDict, + Discriminator, + Field, + SecretStr, + Tag, + TypeAdapter, + field_serializer, + field_validator, +) from openhands.automation.config import get_config +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from openhands.automation.models import Automation, AutomationRun + + +logger = logging.getLogger(__name__) + +_GITHUB_REPO_RE = re.compile(r"^[A-Za-z0-9_.\-]+/[A-Za-z0-9_.\-]+$") + +# Slack channel IDs are uppercase alphanumeric (e.g. 'C0123ABC', 'D…', 'G…'). +# Channel *names* aren't supported here because they're not stable across renames. +_SLACK_CHANNEL_ID_RE = re.compile(r"^[A-Z][A-Z0-9]+$") + + # Allowed URI schemes for tarball_path (includes internal upload scheme) _TARBALL_SCHEME_RE = re.compile(r"^(s3|gs|https?|oh-internal)://") @@ -37,11 +69,50 @@ def _validate_timeout(v: int | None) -> int | None: return v -class CronTrigger(BaseModel): - """Cron-based trigger configuration.""" +class _TriggerBase(BaseModel): + """Common base for all trigger configurations. + + Subclasses implement :meth:`create_pending_run` to decide — once per + scheduler poll cycle — whether the given automation should fire right now, + and if so, what (optional) event payload to attach to the resulting + ``AutomationRun``. + + Concurrency note: the scheduler invokes this method **sequentially** for + each automation in a batch because they share a single + :class:`~sqlalchemy.ext.asyncio.AsyncSession`, which is not safe for + concurrent use. Triggers should still parallelize their *own* external + I/O (e.g. 
by using :func:`openhands.automation.utils.async_utils.wait_all` + to fan out HTTP calls across multiple resources). + """ model_config = ConfigDict(extra="forbid") + type: str + + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, + ) -> AutomationRun | None: + """Return a PENDING ``AutomationRun`` if the trigger is due, else None. + + Implementations that decide to fire MUST call + :func:`openhands.automation.utils.run.create_pending_run` (which + bumps ``last_triggered_at``/``last_polled_at`` and appends the run + to ``session``) and may then mutate the returned run — for example + to populate ``event_payload`` with the data that caused the fire. + + Exceptions raised here are logged and treated as "not due" by the + scheduler so that one broken trigger cannot starve the rest of the + batch. + """ + raise NotImplementedError + + +class CronTrigger(_TriggerBase): + """Cron-based trigger configuration.""" + type: Literal["cron"] = "cron" schedule: str = Field(..., description="Cron expression, e.g. '0 9 * * 5'") timezone: str = Field(default="UTC", description="IANA timezone name") @@ -53,8 +124,26 @@ def validate_cron_schedule(cls, v: str) -> str: raise ValueError(f"Invalid cron expression: {v}") return v + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, + ) -> AutomationRun | None: + """Fire if the cron's most recent slot has passed since last trigger.""" + from openhands.automation.utils.cron import ( + is_automation_due as _is_due_cron, + ) + from openhands.automation.utils.run import ( + create_pending_run as _create_run_util, + ) + + if not _is_due_cron(automation, now): + return None + return await _create_run_util(session, automation) + -class EventTrigger(BaseModel): +class EventTrigger(_TriggerBase): """ Event-based trigger configuration. 
@@ -126,8 +215,6 @@ class EventTrigger(BaseModel): ``` """ - model_config = ConfigDict(extra="forbid") - type: Literal["event"] = "event" source: str = Field( ..., @@ -172,19 +259,536 @@ def event_patterns(self) -> list[str]: return [self.on] return self.on + async def create_pending_run( + self, + session: AsyncSession, # noqa: ARG002 + automation: Automation, # noqa: ARG002 + now: datetime | None = None, # noqa: ARG002 + ) -> AutomationRun | None: + """Event triggers are fired by the webhook router, never by polling.""" + return None + + +class _PollingTriggerBase(_TriggerBase): + """Base for triggers that poll external services for new items. + + Concrete subclasses (e.g. :class:`GithubTrigger`, :class:`SlackTrigger`) + own three pieces of behaviour: + + 1. **Authentication / client construction** — :meth:`_build_client`. + 2. **Fan-out fetching** — :meth:`_fetch_all_new` returns the flat list of + items (already filtered against the cutoff and :attr:`event_filter`) + collected from every configured resource, typically running per-resource + calls concurrently via :func:`wait_all`. + 3. **Payload labelling** — :attr:`_PAYLOAD_SOURCE` and :attr:`_PAYLOAD_KEY` + (e.g. ``"github_trigger"`` / ``"events"``). + + The base class itself handles: + + - The ``enabled`` / ``deleted_at`` short-circuit. + - Computing the cutoff (``last_triggered_at`` or ``created_at``). + - Calling :func:`openhands.automation.utils.run.create_pending_run` once + items are gathered and attaching them to ``run.event_payload``. + - Validating and compiling the optional JMESPath ``event_filter``. + """ + + # Pydantic config. ``arbitrary_types_allowed`` lets us cache a compiled + # JMESPath parser object on ``self.__dict__`` without Pydantic objecting. + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + # Subclass-supplied metadata for the run's ``event_payload`` envelope. 
+ _PAYLOAD_SOURCE: ClassVar[str] = "" + _PAYLOAD_KEY: ClassVar[str] = "items" + + event_filter: str | None = Field( + default=None, + description=( + "Optional JMESPath expression evaluated against each polled item " + "(see https://jmespath.org/). An item matches when the expression " + "returns a truthy value; non-matching items are dropped. The " + "trigger fires only if at least one item matches." + ), + ) + + @field_validator("event_filter") + @classmethod + def validate_event_filter(cls, v: str | None) -> str | None: + """Parse the JMESPath at validation time so bad syntax fails fast. + + The compiled expression itself is *not* stored on the field (so the + JSON round-trip stays clean) — it's cached lazily on first use via + :attr:`_compiled_event_filter`. + """ + if v is None: + return v + expr = v.strip() + if not expr: + return None + try: + jmespath.compile(expr) + except JMESPathError as e: + raise ValueError(f"Invalid JMESPath expression: {e}") from e + return expr + + @property + def _compiled_event_filter(self) -> jmespath.parser.ParsedResult | None: + """Compile and cache the JMESPath expression on first use.""" + if self.event_filter is None: + return None + cached = self.__dict__.get("_compiled_filter") + if cached is None: + # Already validated in ``validate_event_filter`` — won't raise. + cached = jmespath.compile(self.event_filter) + self.__dict__["_compiled_filter"] = cached + return cached + + def _item_matches_filter(self, item: dict[str, Any], context: str) -> bool: + """Return True if ``item`` passes ``event_filter`` (or no filter). + + ``context`` is a human-readable label (repo / channel id) used only + for logging if the JMESPath evaluator throws. 
+ """ + compiled = self._compiled_event_filter + if compiled is None: + return True + try: + return bool(compiled.search(item)) + except JMESPathError as e: + logger.warning("JMESPath evaluation failed for item in %s: %s", context, e) + return False + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated HTTP client for the upstream service.""" + raise NotImplementedError + + async def _fetch_all_new( + self, + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return all items newer than ``cutoff`` from every configured resource. + + Implementations are expected to fan out across their resources + concurrently (typically via + :func:`openhands.automation.utils.async_utils.wait_all`) and apply + :meth:`_item_matches_filter` per item. + """ + raise NotImplementedError + + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, # noqa: ARG002 + ) -> AutomationRun | None: + """Fire if any configured resource yields a matching new item. + + On fire, attaches the gathered items to ``run.event_payload`` as:: + + {"source": <_PAYLOAD_SOURCE>, + <_PAYLOAD_KEY>: [...]} + """ + # Deferred import avoids a circular dependency at module load. 
+ from openhands.automation.utils.run import ( + create_pending_run as _create_run_util, + ) + + if not automation.enabled or automation.deleted_at is not None: + return None + + cutoff = automation.last_triggered_at or automation.created_at + if cutoff is None: + return None + if cutoff.tzinfo is None: + cutoff = cutoff.replace(tzinfo=ZoneInfo("UTC")) + + async with self._build_client() as client: + items = await self._fetch_all_new(client, cutoff) + + if not items: + return None + + run = await _create_run_util(session, automation) + run.event_payload = { + "source": self._PAYLOAD_SOURCE, + self._PAYLOAD_KEY: items, + } + return run + + +class GithubTrigger(_PollingTriggerBase): + """Poll-based trigger that fires when new events appear on GitHub repos. + + On each scheduler poll, the trigger queries the + ``/repos/{owner}/{repo}/events`` endpoint for each configured repository + **concurrently** and collects any event created after the automation's + last fire time (or its ``created_at`` for the very first poll, mirroring + the no-backfill semantics of :class:`CronTrigger`). + + If any matching events are found, :meth:`create_pending_run` creates a + PENDING ``AutomationRun`` and stores the collected events on + ``run.event_payload`` so the run's entrypoint can react to them; otherwise + it returns ``None``. + + Optionally narrow which events fire the trigger using + :attr:`event_filter` — a `JMESPath <https://jmespath.org/>`_ expression + evaluated against each event. An event is kept when the expression + returns a truthy value. Examples:: + + type == 'PushEvent' + type == 'PullRequestEvent' && payload.action == 'opened' + type == 'PushEvent' && payload.ref == 'refs/heads/main' + + When ``event_filter`` is omitted, every event newer than the cutoff + fires the trigger. 
+ """ + + _PAYLOAD_SOURCE: ClassVar[str] = "github_trigger" + _PAYLOAD_KEY: ClassVar[str] = "events" + + type: Literal["github"] = "github" + github_access_token: SecretStr = Field( + ..., + description=( + "GitHub Personal Access Token used to authenticate against the " + "REST API. Authenticated requests have a 5000/hour rate limit " + "compared to 60/hour unauthenticated." + ), + ) + repositories: list[str] = Field( + ..., + min_length=1, + description=( + "Repositories to poll, each as 'owner/name' " + "(e.g. 'All-Hands-AI/OpenHands')." + ), + ) + + @field_validator("repositories") + @classmethod + def validate_repositories(cls, v: list[str]) -> list[str]: + cleaned: list[str] = [] + for repo in v: + repo = repo.strip() + if not _GITHUB_REPO_RE.match(repo): + raise ValueError(f"Invalid repository {repo!r}: expected 'owner/name'") + cleaned.append(repo) + return cleaned + + @field_serializer("github_access_token", when_used="always") + def _serialize_token(self, v: SecretStr) -> str: + """Emit the raw secret so the trigger round-trips through the JSON column. + + The token must be stored in plain text because the scheduler needs + to read it back later to authenticate against GitHub. ``SecretStr`` + is still useful at the application layer: it guards against + accidental logging via ``repr()`` and string interpolation. Treat + the on-disk JSON as sensitive (same trust level as the rest of the + automation config) — anyone with read access to the database can + see it. 
+ """ + return v.get_secret_value() + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated GitHub REST client.""" + return httpx.AsyncClient( + base_url="https://api.github.com", + headers={ + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "User-Agent": "openhands-automation", + "Authorization": ( + f"Bearer {self.github_access_token.get_secret_value()}" + ), + }, + timeout=30.0, + ) + + async def _fetch_new_events( + self, + client: httpx.AsyncClient, + repo: str, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return all matching events for ``repo`` newer than ``cutoff``. + + Errors (HTTP/JSON/non-200) are logged and treated as "no new events" + so a single bad repo doesn't take down the whole trigger. + """ + try: + resp = await client.get( + f"/repos/{repo}/events", + params={"per_page": 30, "page": 1}, + ) + except httpx.HTTPError as e: + logger.warning("GitHub poll failed for %s: %s", repo, e) + return [] + + if resp.status_code != 200: + logger.warning( + "GitHub poll for %s returned status %s", + repo, + resp.status_code, + ) + return [] + + try: + events = resp.json() + except ValueError: + logger.warning("GitHub poll for %s returned non-JSON body", repo) + return [] + if not isinstance(events, list): + return [] + + new_events: list[dict[str, Any]] = [] + for ev in events: + if not isinstance(ev, dict): + continue + if not self._item_matches_filter(ev, repo): + continue + created_raw = ev.get("created_at") + if not isinstance(created_raw, str): + continue + try: + created_at = datetime.fromisoformat(created_raw.replace("Z", "+00:00")) + except ValueError: + continue + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=ZoneInfo("UTC")) + if created_at > cutoff: + # Tag with the source repo so downstream code knows the origin. 
+ tagged = dict(ev) + tagged.setdefault("_repository", repo) + new_events.append(tagged) + return new_events + + async def _fetch_all_new( + self, + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + from openhands.automation.utils.async_utils import wait_all + + per_repo: list[list[dict[str, Any]]] = await wait_all( + [ + self._fetch_new_events(client, repo, cutoff) + for repo in self.repositories + ], + timeout=None, + ) + return [ev for batch in per_repo for ev in batch] + + +class SlackTrigger(_PollingTriggerBase): + """Poll-based trigger that fires when new messages appear in Slack channels. + + On each scheduler poll, the trigger calls + ``https://slack.com/api/conversations.history`` for each configured + channel **concurrently**, using ``oldest=<cutoff-unix-ts>`` as the + high-water mark (derived from the automation's last fire time, or + ``created_at`` for the very first poll). Messages with a Slack ``ts`` + strictly greater than the cutoff are kept. + + Optionally narrow which messages fire the trigger using + :attr:`event_filter` — a `JMESPath <https://jmespath.org/>`_ expression + evaluated against each message dict. A message is kept when the + expression returns a truthy value. Examples:: + + subtype == null # regular user messages only + type == 'message' && user == 'U0123ABC' + contains(text, '@here') + + Required Slack OAuth scopes depend on the channel types polled: + ``channels:history``, ``groups:history``, ``im:history``, ``mpim:history``. + Bot tokens (``xoxb-…``) work but the bot must be a member of each + channel; user tokens (``xoxp-…``) generally don't need membership. + + Notes: + + - The Slack Web API returns HTTP 200 even for application errors; we + detect those via ``body["ok"]`` and log ``body["error"]``. + - On HTTP 429 (rate limit) we log the ``Retry-After`` value and treat + the channel as having no new messages this cycle. The next scheduler + tick will retry. + - Only the first page (up to 200 messages) is fetched per channel per + poll. 
Channels exceeding that between polls will skip the gap — + tighten ``scheduler_interval`` or shorten poll cycles to compensate. + """ + + _PAYLOAD_SOURCE: ClassVar[str] = "slack_trigger" + _PAYLOAD_KEY: ClassVar[str] = "messages" + _SLACK_MESSAGE_PAGE_LIMIT: ClassVar[int] = 200 + + type: Literal["slack"] = "slack" + slack_token: SecretStr = Field( + ..., + description=( + "Slack token used to authenticate against the Web API. Either a " + "user token ('xoxp-…') or a bot token ('xoxb-…'). The token " + "needs the appropriate '*:history' OAuth scopes for the " + "channel types being polled." + ), + ) + channels: list[str] = Field( + ..., + min_length=1, + description=( + "Slack channel IDs to poll (e.g. 'C0123ABC'). IDs — not names — " + "are required because names are not stable across renames. Find a " + "channel's ID in its 'About' panel inside the Slack client." + ), + ) + + @field_validator("channels") + @classmethod + def validate_channels(cls, v: list[str]) -> list[str]: + cleaned: list[str] = [] + for c in v: + c = c.strip() + if not _SLACK_CHANNEL_ID_RE.match(c): + raise ValueError( + f"Invalid Slack channel id {c!r}: expected uppercase " + "alphanumeric like 'C0123ABC' (channel IDs, not names)." + ) + cleaned.append(c) + return cleaned + + @field_serializer("slack_token", when_used="always") + def _serialize_token(self, v: SecretStr) -> str: + """Emit the raw secret so the trigger round-trips through the JSON column. + + See the corresponding note on :class:`GithubTrigger` — the token is + persisted in plain text because the scheduler must reuse it on every + poll. ``SecretStr`` still protects against accidental logging at the + application layer. 
+ """ + return v.get_secret_value() + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated Slack Web API client.""" + return httpx.AsyncClient( + base_url="https://slack.com/api", + headers={ + "Accept": "application/json", + "User-Agent": "openhands-automation", + "Authorization": f"Bearer {self.slack_token.get_secret_value()}", + }, + timeout=30.0, + ) + + async def _fetch_new_messages( + self, + client: httpx.AsyncClient, + channel: str, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return matching messages for ``channel`` newer than ``cutoff``. + + Uses Slack's ``oldest`` parameter (exclusive) for server-side + filtering. Errors — HTTP, JSON, rate-limit, or ``ok=false`` — + are logged and treated as "no new messages" so a single bad + channel cannot take down the whole trigger. + """ + oldest = f"{cutoff.timestamp():.6f}" + try: + resp = await client.get( + "/conversations.history", + params={ + "channel": channel, + "oldest": oldest, + "limit": self._SLACK_MESSAGE_PAGE_LIMIT, + }, + ) + except httpx.HTTPError as e: + logger.warning("Slack poll failed for %s: %s", channel, e) + return [] + + if resp.status_code == 429: + retry_after = resp.headers.get("Retry-After", "unknown") + logger.warning( + "Slack poll for %s rate-limited (Retry-After=%s)", + channel, + retry_after, + ) + return [] + if resp.status_code != 200: + logger.warning( + "Slack poll for %s returned status %s", + channel, + resp.status_code, + ) + return [] + + try: + body = resp.json() + except ValueError: + logger.warning("Slack poll for %s returned non-JSON body", channel) + return [] + if not isinstance(body, dict): + return [] + if not body.get("ok"): + logger.warning( + "Slack poll for %s returned error: %s", + channel, + body.get("error", "unknown"), + ) + return [] + + messages = body.get("messages") or [] + if not isinstance(messages, list): + return [] + + # Slack returns newest-first; deliver in natural chronological order. 
+ messages_sorted = sorted(messages, key=lambda m: float(m.get("ts", "0") or "0")) + + new_messages: list[dict[str, Any]] = [] + for msg in messages_sorted: + if not isinstance(msg, dict): + continue + ts_raw = msg.get("ts") + if not isinstance(ts_raw, str): + continue + try: + ts_unix = float(ts_raw) + except ValueError: + continue + # Defense in depth — `oldest` is exclusive but check anyway. + if ts_unix <= cutoff.timestamp(): + continue + if not self._item_matches_filter(msg, channel): + continue + tagged = dict(msg) + tagged.setdefault("_channel", channel) + new_messages.append(tagged) + return new_messages + + async def _fetch_all_new( + self, + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + from openhands.automation.utils.async_utils import wait_all + + per_channel: list[list[dict[str, Any]]] = await wait_all( + [self._fetch_new_messages(client, c, cutoff) for c in self.channels], + timeout=None, + ) + return [m for batch in per_channel for m in batch] + def _get_trigger_discriminator(v: dict | BaseModel) -> str: """Discriminator function for Pydantic's discriminated union. Returns the trigger type string, which Pydantic uses to select the - correct model (CronTrigger or EventTrigger) from the union. + correct model (CronTrigger, EventTrigger, GithubTrigger, or SlackTrigger) + from the union. Why sentinel instead of raising ValueError: Pydantic discriminator functions must return a string - they cannot raise exceptions. By returning an invalid sentinel value, Pydantic generates a proper ValidationError with context like: "Input tag '__missing_trigger_type__' found using 'type' does not - match any of the expected tags: 'cron', 'event'" + match any of the expected tags: 'cron', 'event', 'github', 'slack'" This produces a user-friendly 422 response via FastAPI. 
""" if isinstance(v, dict): @@ -197,10 +801,19 @@ def _get_trigger_discriminator(v: dict | BaseModel) -> str: # Union type for all triggers, using discriminated union Trigger = Annotated[ - Annotated[CronTrigger, Tag("cron")] | Annotated[EventTrigger, Tag("event")], + Annotated[CronTrigger, Tag("cron")] + | Annotated[EventTrigger, Tag("event")] + | Annotated[GithubTrigger, Tag("github")] + | Annotated[SlackTrigger, Tag("slack")], Discriminator(_get_trigger_discriminator), ] +# Reusable adapter for parsing trigger dicts (e.g. ``automation.trigger`` JSON) +# into the correct ``_TriggerBase`` subclass. +TriggerAdapter: TypeAdapter[ + CronTrigger | EventTrigger | GithubTrigger | SlackTrigger +] = TypeAdapter(Trigger) + class RunStatus(StrEnum): """Status of an automation run (for API responses).""" diff --git a/openhands/automation/utils/__init__.py b/openhands/automation/utils/__init__.py index 5aa4200..98acbaa 100644 --- a/openhands/automation/utils/__init__.py +++ b/openhands/automation/utils/__init__.py @@ -4,6 +4,7 @@ APIKeyError, get_api_key_for_automation_run, ) +from openhands.automation.utils.async_utils import AsyncException, wait_all from openhands.automation.utils.cron import ( get_next_fire_time, get_prev_fire_time, @@ -15,10 +16,12 @@ __all__ = [ "APIKeyError", + "AsyncException", "get_api_key_for_automation_run", "get_next_fire_time", "get_prev_fire_time", "is_automation_due", "log_extra", "utcnow", + "wait_all", ] diff --git a/openhands/automation/utils/async_utils.py b/openhands/automation/utils/async_utils.py new file mode 100644 index 0000000..4783aef --- /dev/null +++ b/openhands/automation/utils/async_utils.py @@ -0,0 +1,60 @@ +"""Async helpers for running coroutines concurrently. 
+ +Adapted from +https://github.com/OpenHands/OpenHands/blob/main/openhands/app_server/utils/async_utils.py +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine, Iterable +from typing import Any + + +GENERAL_TIMEOUT: int = 15 + + +class AsyncException(Exception): + """Raised by ``wait_all`` when more than one task raised an exception.""" + + def __init__(self, exceptions: list[BaseException]) -> None: + self.exceptions = exceptions + super().__init__("\n".join(str(e) for e in exceptions)) + + +async def wait_all( + iterable: Iterable[Coroutine[Any, Any, Any]], + timeout: float | None = GENERAL_TIMEOUT, +) -> list[Any]: + """Run the given coroutines concurrently and wait for them all to finish. + + Returns the results in the original order. If a single task raised an + exception it is re-raised. If multiple tasks raised, an + :class:`AsyncException` containing all of them is raised. If the timeout + elapses any still-pending tasks are cancelled and ``asyncio.TimeoutError`` + is raised. 
+ """ + tasks = [asyncio.create_task(c) for c in iterable] + if not tasks: + return [] + + _, pending = await asyncio.wait(tasks, timeout=timeout) + if pending: + for task in pending: + task.cancel() + raise TimeoutError() + + results: list[Any] = [] + errors: list[BaseException] = [] + for task in tasks: + try: + results.append(task.result()) + except Exception as e: # noqa: BLE001 - propagated below + errors.append(e) + results.append(None) + + if errors: + if len(errors) == 1: + raise errors[0] + raise AsyncException(errors) + return results diff --git a/tests/conftest.py b/tests/conftest.py index 6099085..d3c0fec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,3 +203,111 @@ def mock_settings(): service_key="test-service-key", base_url="http://localhost:8000", ) + + +# --------------------------------------------------------------------------- +# Lightweight SQLite fixtures for trigger / scheduler-helper unit tests +# --------------------------------------------------------------------------- +# These bypass testcontainers/Docker for tests that only need a real session +# to write/read AutomationRun rows. The scheduler code uses ``using_sqlite()`` +# to gate Postgres-only features (``FOR UPDATE SKIP LOCKED``), so SQLite is a +# valid backend here. + + +@pytest.fixture +async def sqlite_engine(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest.fixture +async def sqlite_session_factory(sqlite_engine): + return async_sessionmaker( + sqlite_engine, class_=AsyncSession, expire_on_commit=False + ) + + +@pytest.fixture +async def sqlite_session(sqlite_session_factory): + async with sqlite_session_factory() as session: + yield session + + +@pytest.fixture +def patch_github_transport(monkeypatch): + """Inject an ``httpx.MockTransport`` into ``GithubTrigger._build_client``. 
+ + Returns a callable: given a responder ``(httpx.Request) -> httpx.Response``, + installs it and returns the list of requests the trigger ends up issuing. + """ + import httpx + + from openhands.automation.schemas import GithubTrigger + + def install(responder): + seen: list[httpx.Request] = [] + + def capture(request: httpx.Request) -> httpx.Response: + seen.append(request) + return responder(request) + + transport = httpx.MockTransport(capture) + + def _patched(self: GithubTrigger) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url="https://api.github.com", + headers={ + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "User-Agent": "openhands-automation", + "Authorization": ( + f"Bearer {self.github_access_token.get_secret_value()}" + ), + }, + transport=transport, + ) + + monkeypatch.setattr(GithubTrigger, "_build_client", _patched) + return seen + + return install + + +@pytest.fixture +def patch_slack_transport(monkeypatch): + """Inject an ``httpx.MockTransport`` into ``SlackTrigger._build_client``. + + Returns a callable: given a responder ``(httpx.Request) -> httpx.Response``, + installs it and returns the list of requests the trigger ends up issuing. 
+ """ + import httpx + + from openhands.automation.schemas import SlackTrigger + + def install(responder): + seen: list[httpx.Request] = [] + + def capture(request: httpx.Request) -> httpx.Response: + seen.append(request) + return responder(request) + + transport = httpx.MockTransport(capture) + + def _patched(self: SlackTrigger) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url="https://slack.com/api", + headers={ + "Accept": "application/json", + "User-Agent": "openhands-automation", + "Authorization": f"Bearer {self.slack_token.get_secret_value()}", + }, + transport=transport, + ) + + monkeypatch.setattr(SlackTrigger, "_build_client", _patched) + return seen + + return install diff --git a/tests/test_async_utils.py b/tests/test_async_utils.py new file mode 100644 index 0000000..9539ab9 --- /dev/null +++ b/tests/test_async_utils.py @@ -0,0 +1,64 @@ +"""Tests for openhands.automation.utils.async_utils.""" + +import asyncio + +import pytest + +from openhands.automation.utils.async_utils import AsyncException, wait_all + + +class TestWaitAll: + async def test_empty_iterable_returns_empty_list(self): + assert await wait_all([]) == [] + + async def test_results_in_original_order(self): + async def producer(value: int, delay: float) -> int: + await asyncio.sleep(delay) + return value + + # The slowest coroutine is first; result order must still match input. + results = await wait_all( + [producer(1, 0.03), producer(2, 0.0), producer(3, 0.01)] + ) + assert results == [1, 2, 3] + + async def test_runs_concurrently(self): + # Three 50ms sleeps must complete well under 150ms (serial baseline). 
+ async def sleeper() -> int: + await asyncio.sleep(0.05) + return 1 + + loop = asyncio.get_running_loop() + start = loop.time() + results = await wait_all([sleeper(), sleeper(), sleeper()]) + elapsed = loop.time() - start + assert results == [1, 1, 1] + assert elapsed < 0.13 + + async def test_single_exception_propagates(self): + async def boom() -> None: + raise RuntimeError("kapow") + + async def ok() -> int: + return 42 + + with pytest.raises(RuntimeError, match="kapow"): + await wait_all([ok(), boom()]) + + async def test_multiple_exceptions_wrapped(self): + async def one() -> None: + raise ValueError("a") + + async def two() -> None: + raise ValueError("b") + + with pytest.raises(AsyncException) as exc: + await wait_all([one(), two()]) + assert len(exc.value.exceptions) == 2 + + async def test_timeout_cancels_pending(self): + async def slow() -> None: + await asyncio.sleep(1) + + with pytest.raises(asyncio.TimeoutError): + await wait_all([slow()], timeout=0.05) diff --git a/tests/test_scheduler_create_pending_runs.py b/tests/test_scheduler_create_pending_runs.py new file mode 100644 index 0000000..a2a3aca --- /dev/null +++ b/tests/test_scheduler_create_pending_runs.py @@ -0,0 +1,191 @@ +"""Tests for ``scheduler._create_pending_runs``. + +Exercises the scheduler's per-batch trigger dispatch loop against a real +(SQLite in-memory) session, validating the contract documented on +``_TriggerBase.create_pending_run``: return None ⇒ no run created, return +``AutomationRun`` ⇒ appended to the result list (and persisted by the +trigger itself). + +Only the GitHub HTTP transport is replaced; everything else is real code. 
class TestCreatePendingRuns:
    """Behavioural tests for the scheduler's ``_create_pending_runs`` loop.

    Each test builds ``Automation`` rows with ``_make``, passes them to
    ``_create_pending_runs`` together with an explicit ``now``, and checks
    which automations ended up with an ``AutomationRun``.
    """

    async def test_empty_input(self, sqlite_session):
        """An empty automation list yields an empty result list."""
        assert (
            await _create_pending_runs(
                sqlite_session, [], datetime(2026, 1, 1, tzinfo=UTC)
            )
            == []
        )

    async def test_mixed_triggers_only_due_ones_create_runs(
        self, sqlite_session, patch_github_transport
    ):
        """Only the due cron and the GitHub trigger with new events produce runs."""
        cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC)
        now = cutoff + timedelta(minutes=5)

        # Github API: events for `org/has-events`, nothing for `org/quiet`.
        def responder(req: httpx.Request) -> httpx.Response:
            if "has-events" in req.url.path:
                return httpx.Response(
                    200,
                    json=[
                        {
                            "id": "1",
                            "type": "PushEvent",
                            # One minute after `cutoff`, i.e. newer than the
                            # automation's last_triggered_at.
                            "created_at": (cutoff + timedelta(minutes=1)).strftime(
                                "%Y-%m-%dT%H:%M:%SZ"
                            ),
                        }
                    ],
                )
            return httpx.Response(200, json=[])

        patch_github_transport(responder)

        # Created at 10:25; schedule fires at :00 and :30, so fires exist
        # between creation and `now` (12:05).
        cron_due = await _make(
            sqlite_session,
            trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"},
            created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC),
        )
        # Created at 12:04, i.e. after the most recent fire (12:00) before
        # `now`; expected to be treated as not due.
        cron_not_due = await _make(
            sqlite_session,
            trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"},
            created_at=datetime(2026, 3, 15, 12, 4, tzinfo=UTC),
        )
        # Event trigger: expected to produce no run from this polling path.
        event_not_due = await _make(
            sqlite_session,
            trigger={
                "type": "event",
                "source": "github",
                "on": "pull_request.opened",
            },
        )
        gh_due = await _make(
            sqlite_session,
            trigger={
                "type": "github",
                "github_access_token": "ghp_xxx",
                "repositories": ["org/has-events"],
            },
            last_triggered_at=cutoff,
        )
        gh_not_due = await _make(
            sqlite_session,
            trigger={
                "type": "github",
                "github_access_token": "ghp_xxx",
                "repositories": ["org/quiet"],
            },
            last_triggered_at=cutoff,
        )
        # Unknown trigger type: expected to be skipped rather than abort the batch.
        invalid = await _make(sqlite_session, trigger={"type": "totally-not-a-trigger"})

        runs = await _create_pending_runs(
            sqlite_session,
            [cron_due, cron_not_due, event_not_due, gh_due, gh_not_due, invalid],
            now,
        )
        await sqlite_session.commit()

        run_automation_ids = {r.automation_id for r in runs}
        assert run_automation_ids == {cron_due.id, gh_due.id}

        # GitHub run carries its event payload; cron run does not.
        runs_by_aid = {r.automation_id: r for r in runs}
        assert runs_by_aid[gh_due.id].event_payload is not None
        assert runs_by_aid[gh_due.id].event_payload["source"] == "github_trigger"
        assert len(runs_by_aid[gh_due.id].event_payload["events"]) == 1
        assert runs_by_aid[cron_due.id].event_payload is None

        # After the commit above, querying the table returns rows for exactly
        # the same two automations.
        persisted_ids = {
            r.automation_id
            for r in (await sqlite_session.execute(select(AutomationRun)))
            .scalars()
            .all()
        }
        assert persisted_ids == {cron_due.id, gh_due.id}

    async def test_failing_trigger_does_not_block_others(
        self, sqlite_session, patch_github_transport
    ):
        """A trigger raising must NOT stop other triggers from creating runs."""
        cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC)
        now = cutoff + timedelta(minutes=5)

        # Every GitHub request fails with a connection error. Whether the
        # trigger handles that internally or lets it propagate to the
        # per-trigger ``except`` in ``_create_pending_runs``, the single broken
        # GitHub automation must yield no run and must not affect the cron one.
        # A lambda cannot contain a ``raise`` statement, so the error is raised
        # by throwing it into an empty generator expression.
        patch_github_transport(
            lambda req: (_ for _ in ()).throw(httpx.ConnectError("no route"))
        )

        cron_due = await _make(
            sqlite_session,
            trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"},
            created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC),
        )
        gh_broken = await _make(
            sqlite_session,
            trigger={
                "type": "github",
                "github_access_token": "ghp_xxx",
                "repositories": ["org/exploding"],
            },
            last_triggered_at=cutoff,
        )

        runs = await _create_pending_runs(sqlite_session, [cron_due, gh_broken], now)
        await sqlite_session.commit()

        # Only the cron automation produced a run.
        assert [r.automation_id for r in runs] == [cron_due.id]
Only the GitHub/Slack HTTP transports are replaced via +``httpx.MockTransport`` (see the ``patch_github_transport`` / +``patch_slack_transport`` fixtures); the trigger code — including +header/auth construction, response parsing, timestamp filtering, and +parallel resource fan-out — is untouched. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any + +import pytest +from pydantic import SecretStr, ValidationError +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.models import Automation, AutomationRun +from openhands.automation.schemas import ( + EventTrigger, + GithubTrigger, + SlackTrigger, + TriggerAdapter, +) + + +# --- helpers ---------------------------------------------------------------- + + +TEST_USER_ID = uuid.UUID("12345678-1234-5678-1234-567812345678") +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") + + +async def _make_automation( + session: AsyncSession, + *, + trigger: dict[str, Any], + enabled: bool = True, + deleted_at: datetime | None = None, + created_at: datetime | None = None, + last_triggered_at: datetime | None = None, +) -> Automation: + automation = Automation( + id=uuid.uuid4(), + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="Test", + trigger=trigger, + tarball_path="s3://bucket/code.tar.gz", + entrypoint="uv run main.py", + enabled=enabled, + deleted_at=deleted_at, + created_at=created_at or datetime(2026, 1, 1, tzinfo=UTC), + last_triggered_at=last_triggered_at, + ) + session.add(automation) + await session.flush() + return automation + + +def _gh_event( + event_id: int, created_at: datetime, event_type: str = "PushEvent" +) -> dict[str, Any]: + return { + "id": str(event_id), + "type": event_type, + "created_at": created_at.strftime("%Y-%m-%dT%H:%M:%SZ"), + } + + +# --------------------------------------------------------------------------- +# CronTrigger.create_pending_run +# 
class TestCronTriggerCreatesRun:
    """Tests for the cron trigger's ``create_pending_run``.

    The trigger object is obtained by validating the config dict through
    ``TriggerAdapter``; the automation row is created with ``_make_automation``.
    """

    async def test_creates_run_when_due(self, sqlite_session):
        """A fire time between creation and ``now`` produces a persisted run."""
        trigger_cfg = {
            "type": "cron",
            "schedule": "0,30 * * * *",
            "timezone": "UTC",
        }
        # Fires at :00 and :30. Created at 10:25 (5 min before the 10:30 fire)
        # and polled at 10:35 (5 min after it).
        automation = await _make_automation(
            sqlite_session,
            trigger=trigger_cfg,
            created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC),
        )
        trigger = TriggerAdapter.validate_python(trigger_cfg)
        now = datetime(2026, 3, 15, 10, 35, tzinfo=UTC)

        run = await trigger.create_pending_run(sqlite_session, automation, now)
        await sqlite_session.commit()

        assert run is not None
        assert run.automation_id == automation.id
        # Round-trip via DB to confirm the row was persisted.
        from_db = (
            (
                await sqlite_session.execute(
                    select(AutomationRun).where(AutomationRun.id == run.id)
                )
            )
            .scalars()
            .first()
        )
        assert from_db is not None
        # Cron triggers don't attach an event payload.
        assert from_db.event_payload is None
        # last_triggered_at started as None (the helper's default) and must
        # have been set while the run was created.
        assert automation.last_triggered_at is not None

    async def test_returns_none_when_not_due(self, sqlite_session):
        """No run when the latest fire time precedes the automation's creation."""
        trigger_cfg = {
            "type": "cron",
            "schedule": "0,30 * * * *",
            "timezone": "UTC",
        }
        # Automation was JUST created — the most recent fire (10:30) is BEFORE
        # creation (10:34), so it shouldn't fire yet.
        automation = await _make_automation(
            sqlite_session,
            trigger=trigger_cfg,
            created_at=datetime(2026, 3, 15, 10, 34, tzinfo=UTC),
        )
        trigger = TriggerAdapter.validate_python(trigger_cfg)
        now = datetime(2026, 3, 15, 10, 35, tzinfo=UTC)

        assert await trigger.create_pending_run(sqlite_session, automation, now) is None

    async def test_returns_none_when_disabled(self, sqlite_session):
        """A disabled automation yields no run even on an every-minute schedule."""
        trigger_cfg = {"type": "cron", "schedule": "* * * * *", "timezone": "UTC"}
        automation = await _make_automation(
            sqlite_session, trigger=trigger_cfg, enabled=False
        )
        trigger = TriggerAdapter.validate_python(trigger_cfg)
        # NOTE(review): ``now`` is omitted here, so this relies on
        # create_pending_run supplying its own default — confirm.
        assert await trigger.create_pending_run(sqlite_session, automation) is None
+ assert "ghp_xxx" not in repr(trigger) + assert trigger.github_access_token.get_secret_value() == "ghp_xxx" + + def test_invalid_repository_format_rejected(self): + with pytest.raises(ValidationError, match="Invalid repository"): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["not-a-repo"], + ) + + def test_empty_repositories_rejected(self): + with pytest.raises(ValidationError): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=[], + ) + + def test_valid_jmespath_event_filter_is_accepted(self): + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="type == 'PushEvent' && payload.ref == 'refs/heads/main'", + ) + assert ( + trigger.event_filter + == "type == 'PushEvent' && payload.ref == 'refs/heads/main'" + ) + + def test_invalid_jmespath_event_filter_rejected(self): + with pytest.raises(ValidationError, match="Invalid JMESPath expression"): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="this is not valid jmespath @@@", + ) + + def test_empty_event_filter_becomes_none(self): + # Whitespace-only filter shouldn't error and shouldn't be retained. 
+ trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter=" ", + ) + assert trigger.event_filter is None + + def test_discriminated_union_dispatches_to_github(self): + parsed = TriggerAdapter.validate_python( + { + "type": "github", + "github_access_token": "ghp_yyy", + "repositories": ["foo/bar"], + } + ) + assert isinstance(parsed, GithubTrigger) + + +# --------------------------------------------------------------------------- +# GithubTrigger.create_pending_run +# --------------------------------------------------------------------------- + + +class TestGithubTriggerCreatesRun: + async def test_fires_and_attaches_events_to_payload( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + fresh = _gh_event(2, cutoff + timedelta(minutes=5)) + seen = patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[fresh]) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + await sqlite_session.commit() + + assert run is not None + assert run.automation_id == automation.id + assert run.event_payload is not None + assert run.event_payload["source"] == "github_trigger" + assert len(run.event_payload["events"]) == 1 + event = run.event_payload["events"][0] + assert event["id"] == "2" + assert event["type"] == "PushEvent" + # The trigger tags each event with the source repo. + assert event["_repository"] == "foo/bar" + + # Sanity-check the outgoing HTTP request. 
+ assert len(seen) == 1 + assert seen[0].url.path == "/repos/foo/bar/events" + assert seen[0].headers["Authorization"] == "Bearer ghp_xxx" + + async def test_returns_none_when_no_new_events( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + # All events are older than cutoff. + stale = _gh_event(1, cutoff - timedelta(minutes=10)) + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[stale]) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + assert await trigger.create_pending_run(sqlite_session, automation) is None + # No run was persisted. + runs = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert runs == [] + + async def test_jmespath_filter_excludes_non_matching( + self, sqlite_session, patch_github_transport + ): + """``event_filter`` drops events whose JMESPath expression is falsy.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [_gh_event(2, cutoff + timedelta(minutes=5), event_type="IssuesEvent")] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="type == 'PushEvent'", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_jmespath_filter_keeps_matching( + self, sqlite_session, patch_github_transport + ): + """A truthy JMESPath result keeps the event and fires the trigger.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [ + _gh_event(7, cutoff + timedelta(minutes=1), 
event_type="IssuesEvent"), + _gh_event(8, cutoff + timedelta(minutes=2), event_type="PushEvent"), + ] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="type == 'PushEvent'", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + # Only the PushEvent survived the filter. + kept = run.event_payload["events"] + assert [e["id"] for e in kept] == ["8"] + + async def test_jmespath_filter_supports_nested_payload( + self, sqlite_session, patch_github_transport + ): + """JMESPath can match against the nested ``payload`` object.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [ + { + "id": "9", + "type": "PullRequestEvent", + "created_at": (cutoff + timedelta(minutes=1)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "payload": {"action": "closed"}, + }, + { + "id": "10", + "type": "PullRequestEvent", + "created_at": (cutoff + timedelta(minutes=2)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "payload": {"action": "opened"}, + }, + ] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter=("type == 'PullRequestEvent' && payload.action == 'opened'"), + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["events"] + assert [e["id"] for e in kept] == ["10"] + + async def test_uses_created_at_when_never_triggered( + self, sqlite_session, patch_github_transport + ): + created = 
datetime(2026, 3, 15, 9, 0, 0, tzinfo=UTC) + events = [_gh_event(3, created + timedelta(hours=1))] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + created_at=created, + last_triggered_at=None, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + assert len(run.event_payload["events"]) == 1 + + async def test_disabled_short_circuits_before_http( + self, sqlite_session, patch_github_transport + ): + seen = patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[]) + ) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + enabled=False, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + assert seen == [] # Must short-circuit BEFORE hitting the network. 
+ + async def test_non_200_response_is_not_due( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_github_transport( + lambda req: __import__("httpx").Response( + 403, json={"message": "rate limit"} + ) + ) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_collects_events_across_repos( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + import httpx + + def responder(req: httpx.Request) -> httpx.Response: + if "foo/bar" in req.url.path: + return httpx.Response( + 200, + json=[_gh_event(1, cutoff + timedelta(minutes=1))], + ) + return httpx.Response( + 200, + json=[_gh_event(2, cutoff + timedelta(minutes=2))], + ) + + seen = patch_github_transport(responder) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar", "foo/baz"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + # Both repos contributed an event. + events = run.event_payload["events"] + assert len(events) == 2 + repos = {ev["_repository"] for ev in events} + assert repos == {"foo/bar", "foo/baz"} + # Both repos were polled (parallel fan-out). 
+ polled = {r.url.path for r in seen} + assert polled == {"/repos/foo/bar/events", "/repos/foo/baz/events"} + + +# --------------------------------------------------------------------------- +# SlackTrigger +# --------------------------------------------------------------------------- + + +def _slack_message( + msg_id: str, + ts_dt: datetime, + *, + text: str = "hello", + user: str = "U0123ABC", + msg_type: str = "message", + subtype: str | None = None, +) -> dict[str, Any]: + """Build a fake Slack message dict. + + Slack ``ts`` is a string of unix seconds with microsecond precision. + """ + msg: dict[str, Any] = { + "type": msg_type, + "user": user, + "text": text, + "ts": f"{ts_dt.timestamp():.6f}", + "client_msg_id": msg_id, + } + if subtype is not None: + msg["subtype"] = subtype + return msg + + +def _slack_ok(messages: list[dict[str, Any]]) -> dict[str, Any]: + return {"ok": True, "messages": messages, "has_more": False} + + +class TestSlackTriggerValidation: + def test_valid_channel_ids_are_accepted(self): + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC", "D0SOMEDM"], + ) + assert trigger.channels == ["C0123ABC", "D0SOMEDM"] + + def test_lowercase_channel_id_rejected(self): + with pytest.raises(ValidationError, match="Invalid Slack channel id"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["c0123abc"], + ) + + def test_channel_name_rejected(self): + # '#general' isn't a stable id; ensure we reject names early. + with pytest.raises(ValidationError, match="Invalid Slack channel id"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["#general"], + ) + + def test_empty_channels_rejected(self): + with pytest.raises(ValidationError): + SlackTrigger(slack_token=SecretStr("xoxb-xxx"), channels=[]) + + def test_invalid_jmespath_event_filter_rejected(self): + # The JMESPath validation lives on the shared polling base, so this + # also exercises that SlackTrigger inherits it correctly. 
+ with pytest.raises(ValidationError, match="Invalid JMESPath expression"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="not valid @@@", + ) + + def test_token_round_trips_through_model_dump(self): + """``model_dump`` must emit the raw string so the trigger can be + persisted in the AutomationsTrigger JSON column and rehydrated.""" + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-supersecret"), + channels=["C0123ABC"], + ) + dumped = trigger.model_dump(mode="python") + assert dumped["slack_token"] == "xoxb-supersecret" + # Round-trip through the discriminated union too. + roundtripped = TriggerAdapter.validate_python(dumped) + assert isinstance(roundtripped, SlackTrigger) + assert roundtripped.slack_token.get_secret_value() == "xoxb-supersecret" + + def test_discriminated_union_dispatches_to_slack(self): + parsed = TriggerAdapter.validate_python( + { + "type": "slack", + "slack_token": "xoxb-xxx", + "channels": ["C0123ABC"], + } + ) + assert isinstance(parsed, SlackTrigger) + + +class TestSlackTriggerCreatesRun: + async def test_new_message_creates_run_with_payload( + self, sqlite_session, patch_slack_transport + ): + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msg = _slack_message("m1", cutoff + timedelta(seconds=30)) + seen = patch_slack_transport( + lambda req: httpx.Response(200, json=_slack_ok([msg])) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + assert run.event_payload["source"] == "slack_trigger" + kept = run.event_payload["messages"] + assert [m["client_msg_id"] for m in kept] == ["m1"] + # Channel was tagged onto the message. 
+ assert kept[0]["_channel"] == "C0123ABC" + # The Slack `oldest` param uses unix seconds derived from the cutoff. + assert len(seen) == 1 + assert seen[0].url.path == "/api/conversations.history" + assert seen[0].url.params["channel"] == "C0123ABC" + assert float(seen[0].url.params["oldest"]) == pytest.approx(cutoff.timestamp()) + # And the run was persisted. + rows = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert len(rows) == 1 + + async def test_no_new_messages_returns_none( + self, sqlite_session, patch_slack_transport + ): + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok([]))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + rows = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert rows == [] + + async def test_slack_ok_false_treated_as_no_messages( + self, sqlite_session, patch_slack_transport + ): + """Slack returns HTTP 200 for app errors; ``ok=false`` must not fire.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport( + lambda req: httpx.Response( + 200, json={"ok": False, "error": "channel_not_found"} + ) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_rate_limit_treated_as_no_messages( + self, sqlite_session, patch_slack_transport + ): + """HTTP 429 (rate limit) is logged and treated as no new messages.""" + import httpx + + 
cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport( + lambda req: httpx.Response(429, headers={"Retry-After": "5"}, json={}) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_message_at_or_before_cutoff_excluded( + self, sqlite_session, patch_slack_transport + ): + """A message exactly at the cutoff or older must not fire the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + # Slack's `oldest` is server-side exclusive — but it's perfectly + # reasonable for the API to still hand us a message at the cutoff + # (e.g. if the float rounds). Our client-side check must reject it. + stale = _slack_message("m_old", cutoff) + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok([stale]))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_jmespath_filter_keeps_matching_messages( + self, sqlite_session, patch_slack_transport + ): + """A truthy JMESPath result keeps the message and fires the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msgs = [ + _slack_message( + "join", + cutoff + timedelta(seconds=10), + subtype="channel_join", + text="<@U1> has joined", + ), + _slack_message( + "real", + cutoff + timedelta(seconds=20), + text="real user message", + ), + ] + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok(msgs))) + + # Filter out join/leave noise — only regular user messages. 
+ trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="subtype == null", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["messages"] + assert [m["client_msg_id"] for m in kept] == ["real"] + + async def test_jmespath_filter_excludes_all( + self, sqlite_session, patch_slack_transport + ): + """A filter that matches nothing must not fire the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msgs = [ + _slack_message("a", cutoff + timedelta(seconds=10), user="U1"), + _slack_message("b", cutoff + timedelta(seconds=20), user="U2"), + ] + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok(msgs))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="user == 'UNOBODY'", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_multiple_channels_polled_in_parallel( + self, sqlite_session, patch_slack_transport + ): + """Two configured channels are each polled; both contribute messages.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + + def responder(request: httpx.Request) -> httpx.Response: + ch = request.url.params["channel"] + msg = _slack_message(f"msg_{ch}", cutoff + timedelta(seconds=5)) + return httpx.Response(200, json=_slack_ok([msg])) + + seen = patch_slack_transport(responder) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC", "C0456DEF"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + 
last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["messages"] + assert {m["_channel"] for m in kept} == {"C0123ABC", "C0456DEF"} + polled = {r.url.params["channel"] for r in seen} + assert polled == {"C0123ABC", "C0456DEF"} + + async def test_disabled_short_circuits_without_calling_slack( + self, sqlite_session, patch_slack_transport + ): + """A disabled automation must not call Slack at all.""" + import httpx + + seen = patch_slack_transport( + lambda req: httpx.Response(200, json=_slack_ok([])) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + enabled=False, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + assert seen == []