From 3a06954bd04ec9cc7e62b7a27e63844203092d56 Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 15 May 2026 20:08:02 +0000 Subject: [PATCH 1/3] feat: add GithubTrigger and refactor trigger interface around create_pending_run Adds a new poll-based trigger type that fires when configured GitHub repositories receive new events. Refactors the trigger abstraction so each subclass owns the decision to fire AND the resulting AutomationRun, which lets GithubTrigger attach the triggering events to run.event_payload. Trigger interface (openhands/automation/schemas.py) --------------------------------------------------- - New _TriggerBase abstract class with async create_pending_run(session, automation, now=None) -> AutomationRun | None. None means 'not due'. - CronTrigger delegates to the existing cron evaluator, then calls utils.run.create_pending_run on a fire. - EventTrigger always returns None (webhook-driven, never polled). - GithubTrigger: * github_access_token (SecretStr) + repositories (list[str]) + optional event_types allow-list. * Polls /repos/{owner}/{repo}/events for every configured repo concurrently via wait_all. * Collects every matching event newer than last_triggered_at (or created_at for the first poll) and attaches them to run.event_payload as {source: 'github_trigger', events: [...]}. * Each event is tagged with its source _repository. - Added @field_serializer for github_access_token so SecretStr round-trips through the JSON column (fixes a latent persistence crash). - Added TriggerAdapter (TypeAdapter[Trigger]) so callers can validate the discriminated union directly. Concurrency helper (openhands/automation/utils/async_utils.py) -------------------------------------------------------------- - New wait_all() that mirrors openhands SDK's async_utils.wait_all: runs an iterable of coroutines concurrently, preserves input order, aggregates exceptions, supports optional timeout. Scheduler (openhands/automation/scheduler.py) --------------------------------------------- - Replaced the inline cron-only is_automation_due filter at L120 with _create_pending_runs(session, automations, now), which: * Parses each automation.trigger JSON into a typed trigger model (ValidationError -> log and skip). * Awaits trigger.create_pending_run sequentially (shared AsyncSession is not safe for concurrent use; cross-trigger parallelism still happens inside each trigger). * Skips per-trigger exceptions so one bad trigger cannot starve the batch. - Generalised the firing log line (no longer assumes cron-only trigger shape). Tests ----- - tests/test_async_utils.py: ordering, concurrency, exception aggregation, timeout cancellation. - tests/test_triggers.py: cron/event/github create_pending_run paths, GithubTrigger validation, event-type filtering, never-triggered cutoff, disabled short-circuit, non-200 handling, multi-repo event collection. Uses real SQLite-in-memory sessions; only the GitHub HTTP transport is replaced via httpx.MockTransport. - tests/test_scheduler_create_pending_runs.py: mixed-trigger batches, invalid trigger config skipped, one failing trigger does not block the rest. - tests/conftest.py: new sqlite_session / sqlite_session_factory / sqlite_engine / patch_github_transport fixtures. 499 unit tests pass; the only errors in this sandbox are testcontainer setups that need a Docker daemon. Co-authored-by: openhands --- openhands/automation/scheduler.py | 79 ++-- openhands/automation/schemas.py | 320 +++++++++++++++- openhands/automation/utils/__init__.py | 3 + openhands/automation/utils/async_utils.py | 60 +++ tests/conftest.py | 71 ++++ tests/test_async_utils.py | 64 ++++ tests/test_scheduler_create_pending_runs.py | 191 ++++++++++ tests/test_triggers.py | 392 ++++++++++++++++++++ 8 files changed, 1149 insertions(+), 31 deletions(-) create mode 100644 openhands/automation/utils/async_utils.py create mode 100644 tests/test_async_utils.py create mode 100644 tests/test_scheduler_create_pending_runs.py create mode 100644 tests/test_triggers.py diff --git a/openhands/automation/scheduler.py b/openhands/automation/scheduler.py index ac1d5c0..457f37b 100644 --- a/openhands/automation/scheduler.py +++ b/openhands/automation/scheduler.py @@ -12,13 +12,14 @@ import logging from datetime import datetime, timedelta +from pydantic import ValidationError from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from openhands.automation.db import using_sqlite from openhands.automation.models import Automation, AutomationRun -from openhands.automation.utils import is_automation_due, utcnow -from openhands.automation.utils.run import create_pending_run +from openhands.automation.schemas import TriggerAdapter +from openhands.automation.utils import utcnow logger = logging.getLogger("automation.scheduler") @@ -74,6 +75,60 @@ async def _fetch_enabled_automations( return list(result.scalars().all()) +async def _create_pending_runs( + session: AsyncSession, + automations: list[Automation], + now: datetime, +) -> list[AutomationRun]: + """Ask each automation's trigger to create a PENDING run if it's due. + + Each automation's ``trigger`` JSON is parsed into a typed model + (``CronTrigger``/``EventTrigger``/``GithubTrigger``) and its + :meth:`_TriggerBase.create_pending_run` is awaited. A return value of + ``None`` means "not due"; any other return is appended to the result. + + Calls run **sequentially** because they share a single + :class:`AsyncSession` (SQLAlchemy async sessions are not safe for + concurrent use). Triggers that need to fan out external I/O — e.g. + :class:`~openhands.automation.schemas.GithubTrigger` polling multiple + repos — do so internally. + + Per-trigger exceptions are logged and skipped so one bad trigger cannot + starve the rest of the batch. + """ + created: list[AutomationRun] = [] + for automation in automations: + try: + trigger = TriggerAdapter.validate_python(automation.trigger) + except ValidationError: + logger.exception("Invalid trigger config for automation %s", automation.id) + continue + + try: + run = await trigger.create_pending_run(session, automation, now) + except Exception: + logger.exception( + "Trigger %s raised while processing automation %s", + type(trigger).__name__, + automation.id, + ) + continue + + if run is None: + continue + + created.append(run) + logger.info( + "Created pending run: run_id=%s automation_id=%s name=%s trigger_type=%s", + run.id, + automation.id, + automation.name, + automation.trigger.get("type"), + ) + + return created + + async def poll_and_schedule( session_factory: async_sessionmaker[AsyncSession], batch_size: int = DEFAULT_BATCH_SIZE, @@ -117,25 +172,7 @@ async def poll_and_schedule( for automation in automations: automation.last_polled_at = now - due_automations = [a for a in automations if is_automation_due(a, now)] - - for automation in due_automations: - try: - run = await create_pending_run(session, automation) - created_runs.append(run) - logger.info( - "Created pending run: run_id=%s automation_id=%s " - "name=%s schedule=%s", - run.id, - automation.id, - automation.name, - automation.trigger.get("schedule"), - ) - except Exception: - logger.exception( - "Failed to create run for automation %s", - automation.id, - ) + created_runs.extend(await _create_pending_runs(session, automations, now)) # Always commit to release row locks from FOR UPDATE SKIP LOCKED, # even if no runs were created diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index cf7d6db..4802ca2 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -1,17 +1,43 @@ """Pydantic request/response schemas for the API.""" +from __future__ import annotations + +import logging import re import uuid from datetime import datetime from enum import StrEnum -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal +from zoneinfo import ZoneInfo +import httpx from croniter import croniter -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Discriminator, + Field, + SecretStr, + Tag, + TypeAdapter, + field_serializer, + field_validator, +) from openhands.automation.config import get_config +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession + + from openhands.automation.models import Automation, AutomationRun + + +logger = logging.getLogger(__name__) + +_GITHUB_REPO_RE = re.compile(r"^[A-Za-z0-9_.\-]+/[A-Za-z0-9_.\-]+$") + + # Allowed URI schemes for tarball_path (includes internal upload scheme) _TARBALL_SCHEME_RE = re.compile(r"^(s3|gs|https?|oh-internal)://") @@ -37,11 +63,50 @@ def _validate_timeout(v: int | None) -> int | None: return v -class CronTrigger(BaseModel): - """Cron-based trigger configuration.""" +class _TriggerBase(BaseModel): + """Common base for all trigger configurations. + + Subclasses implement :meth:`create_pending_run` to decide — once per + scheduler poll cycle — whether the given automation should fire right now, + and if so, what (optional) event payload to attach to the resulting + ``AutomationRun``. + + Concurrency note: the scheduler invokes this method **sequentially** for + each automation in a batch because they share a single + :class:`~sqlalchemy.ext.asyncio.AsyncSession`, which is not safe for + concurrent use. Triggers should still parallelize their *own* external + I/O (e.g. by using :func:`openhands.automation.utils.async_utils.wait_all` + to fan out HTTP calls across multiple resources). + """ model_config = ConfigDict(extra="forbid") + type: str + + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, + ) -> AutomationRun | None: + """Return a PENDING ``AutomationRun`` if the trigger is due, else None. + + Implementations that decide to fire MUST call + :func:`openhands.automation.utils.run.create_pending_run` (which + bumps ``last_triggered_at``/``last_polled_at`` and appends the run + to ``session``) and may then mutate the returned run — for example + to populate ``event_payload`` with the data that caused the fire. + + Exceptions raised here are logged and treated as "not due" by the + scheduler so that one broken trigger cannot starve the rest of the + batch. + """ + raise NotImplementedError + + +class CronTrigger(_TriggerBase): + """Cron-based trigger configuration.""" + type: Literal["cron"] = "cron" schedule: str = Field(..., description="Cron expression, e.g. '0 9 * * 5'") timezone: str = Field(default="UTC", description="IANA timezone name") @@ -53,8 +118,26 @@ def validate_cron_schedule(cls, v: str) -> str: raise ValueError(f"Invalid cron expression: {v}") return v + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, + ) -> AutomationRun | None: + """Fire if the cron's most recent slot has passed since last trigger.""" + from openhands.automation.utils.cron import ( + is_automation_due as _is_due_cron, + ) + from openhands.automation.utils.run import ( + create_pending_run as _create_run_util, + ) + + if not _is_due_cron(automation, now): + return None + return await _create_run_util(session, automation) + -class EventTrigger(BaseModel): +class EventTrigger(_TriggerBase): """ Event-based trigger configuration. @@ -126,8 +209,6 @@ class EventTrigger(BaseModel): ``` """ - model_config = ConfigDict(extra="forbid") - type: Literal["event"] = "event" source: str = Field( ..., @@ -172,19 +253,230 @@ def event_patterns(self) -> list[str]: return [self.on] return self.on + async def create_pending_run( + self, + session: AsyncSession, # noqa: ARG002 + automation: Automation, # noqa: ARG002 + now: datetime | None = None, # noqa: ARG002 + ) -> AutomationRun | None: + """Event triggers are fired by the webhook router, never by polling.""" + return None + + +class GithubTrigger(_TriggerBase): + """Poll-based trigger that fires when new events appear on GitHub repos. + + On each scheduler poll, the trigger queries the + ``/repos/{owner}/{repo}/events`` endpoint for each configured repository + **concurrently** and collects any event created after the automation's + last fire time (or its ``created_at`` for the very first poll, mirroring + the no-backfill semantics of :class:`CronTrigger`). + + If any matching events are found, :meth:`create_pending_run` creates a + PENDING ``AutomationRun`` and stores the collected events on + ``run.event_payload`` so the run's entrypoint can react to them; otherwise + it returns ``None``. + + Optionally restrict the event types that count using ``event_types`` (e.g. + ``["PushEvent", "PullRequestEvent"]``); when omitted, any event type + triggers a fire. + """ + + type: Literal["github"] = "github" + github_access_token: SecretStr = Field( + ..., + description=( + "GitHub Personal Access Token used to authenticate against the " + "REST API. Authenticated requests have a 5000/hour rate limit " + "compared to 60/hour unauthenticated." + ), + ) + repositories: list[str] = Field( + ..., + min_length=1, + description=( + "Repositories to poll, each as 'owner/name' " + "(e.g. 'All-Hands-AI/OpenHands')." + ), + ) + event_types: list[str] | None = Field( + default=None, + description=( + "Optional allow-list of GitHub event types (e.g. 'PushEvent'). " + "When unset, any event type counts as new activity." + ), + ) + + @field_validator("repositories") + @classmethod + def validate_repositories(cls, v: list[str]) -> list[str]: + cleaned: list[str] = [] + for repo in v: + repo = repo.strip() + if not _GITHUB_REPO_RE.match(repo): + raise ValueError(f"Invalid repository {repo!r}: expected 'owner/name'") + cleaned.append(repo) + return cleaned + + @field_validator("event_types") + @classmethod + def validate_event_types(cls, v: list[str] | None) -> list[str] | None: + if v is None: + return v + cleaned = [t.strip() for t in v if t and t.strip()] + return cleaned or None + + @field_serializer("github_access_token", when_used="always") + def _serialize_token(self, v: SecretStr) -> str: + """Emit the raw secret so the trigger can round-trip through the + JSON column. + + The token must be stored in plain text because the scheduler needs + to read it back later to authenticate against GitHub. ``SecretStr`` + is still useful at the application layer: it guards against + accidental logging via ``repr()`` and string interpolation. Treat + the on-disk JSON as sensitive (same trust level as the rest of the + automation config) — anyone with read access to the database can + see it. + """ + return v.get_secret_value() + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated GitHub REST client.""" + return httpx.AsyncClient( + base_url="https://api.github.com", + headers={ + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "User-Agent": "openhands-automation", + "Authorization": ( + f"Bearer {self.github_access_token.get_secret_value()}" + ), + }, + timeout=30.0, + ) + + async def _fetch_new_events( + self, + client: httpx.AsyncClient, + repo: str, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return all matching events for ``repo`` newer than ``cutoff``. + + Honours :attr:`event_types`. Errors (HTTP/JSON/non-200) are logged + and treated as "no new events" so a single bad repo doesn't take + down the whole trigger. + """ + try: + resp = await client.get( + f"/repos/{repo}/events", + params={"per_page": 30, "page": 1}, + ) + except httpx.HTTPError as e: + logger.warning("GitHub poll failed for %s: %s", repo, e) + return [] + + if resp.status_code != 200: + logger.warning( + "GitHub poll for %s returned status %s", + repo, + resp.status_code, + ) + return [] + + try: + events = resp.json() + except ValueError: + logger.warning("GitHub poll for %s returned non-JSON body", repo) + return [] + if not isinstance(events, list): + return [] + + allowed: set[str] | None = set(self.event_types) if self.event_types else None + new_events: list[dict[str, Any]] = [] + for ev in events: + if not isinstance(ev, dict): + continue + if allowed is not None and ev.get("type") not in allowed: + continue + created_raw = ev.get("created_at") + if not isinstance(created_raw, str): + continue + try: + created_at = datetime.fromisoformat(created_raw.replace("Z", "+00:00")) + except ValueError: + continue + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=ZoneInfo("UTC")) + if created_at > cutoff: + # Tag with repo so downstream code knows where it came from. + tagged = dict(ev) + tagged.setdefault("_repository", repo) + new_events.append(tagged) + return new_events + + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, # noqa: ARG002 + ) -> AutomationRun | None: + """Fire if any configured repo has matching new events. + + On fire, attaches the events that caused the fire to + ``run.event_payload`` as:: + + { + "source": "github_trigger", + "events": [, ...], + } + """ + # Deferred imports avoid circular dependencies at module load time. + from openhands.automation.utils.async_utils import wait_all + from openhands.automation.utils.run import ( + create_pending_run as _create_run_util, + ) + + if not automation.enabled or automation.deleted_at is not None: + return None + + cutoff = automation.last_triggered_at or automation.created_at + if cutoff is None: + return None + if cutoff.tzinfo is None: + cutoff = cutoff.replace(tzinfo=ZoneInfo("UTC")) + + async with self._build_client() as client: + per_repo: list[list[dict[str, Any]]] = await wait_all( + [ + self._fetch_new_events(client, repo, cutoff) + for repo in self.repositories + ], + timeout=None, + ) + + all_events: list[dict[str, Any]] = [ev for batch in per_repo for ev in batch] + if not all_events: + return None + + run = await _create_run_util(session, automation) + run.event_payload = {"source": "github_trigger", "events": all_events} + return run + def _get_trigger_discriminator(v: dict | BaseModel) -> str: """Discriminator function for Pydantic's discriminated union. Returns the trigger type string, which Pydantic uses to select the - correct model (CronTrigger or EventTrigger) from the union. + correct model (CronTrigger, EventTrigger, or GithubTrigger) from the union. Why sentinel instead of raising ValueError: Pydantic discriminator functions must return a string - they cannot raise exceptions. By returning an invalid sentinel value, Pydantic generates a proper ValidationError with context like: "Input tag '__missing_trigger_type__' found using 'type' does not - match any of the expected tags: 'cron', 'event'" + match any of the expected tags: 'cron', 'event', 'github'" This produces a user-friendly 422 response via FastAPI. """ if isinstance(v, dict): @@ -197,10 +489,18 @@ def _get_trigger_discriminator(v: dict | BaseModel) -> str: # Union type for all triggers, using discriminated union Trigger = Annotated[ - Annotated[CronTrigger, Tag("cron")] | Annotated[EventTrigger, Tag("event")], + Annotated[CronTrigger, Tag("cron")] + | Annotated[EventTrigger, Tag("event")] + | Annotated[GithubTrigger, Tag("github")], Discriminator(_get_trigger_discriminator), ] +# Reusable adapter for parsing trigger dicts (e.g. ``automation.trigger`` JSON) +# into the correct ``_TriggerBase`` subclass. +TriggerAdapter: TypeAdapter[CronTrigger | EventTrigger | GithubTrigger] = TypeAdapter( + Trigger +) + class RunStatus(StrEnum): """Status of an automation run (for API responses).""" diff --git a/openhands/automation/utils/__init__.py b/openhands/automation/utils/__init__.py index 5aa4200..98acbaa 100644 --- a/openhands/automation/utils/__init__.py +++ b/openhands/automation/utils/__init__.py @@ -4,6 +4,7 @@ APIKeyError, get_api_key_for_automation_run, ) +from openhands.automation.utils.async_utils import AsyncException, wait_all from openhands.automation.utils.cron import ( get_next_fire_time, get_prev_fire_time, @@ -15,10 +16,12 @@ __all__ = [ "APIKeyError", + "AsyncException", "get_api_key_for_automation_run", "get_next_fire_time", "get_prev_fire_time", "is_automation_due", "log_extra", "utcnow", + "wait_all", ] diff --git a/openhands/automation/utils/async_utils.py b/openhands/automation/utils/async_utils.py new file mode 100644 index 0000000..4783aef --- /dev/null +++ b/openhands/automation/utils/async_utils.py @@ -0,0 +1,60 @@ +"""Async helpers for running coroutines concurrently. + +Adapted from +https://github.com/OpenHands/OpenHands/blob/main/openhands/app_server/utils/async_utils.py +""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine, Iterable +from typing import Any + + +GENERAL_TIMEOUT: int = 15 + + +class AsyncException(Exception): + """Raised by ``wait_all`` when more than one task raised an exception.""" + + def __init__(self, exceptions: list[BaseException]) -> None: + self.exceptions = exceptions + super().__init__("\n".join(str(e) for e in exceptions)) + + +async def wait_all( + iterable: Iterable[Coroutine[Any, Any, Any]], + timeout: float | None = GENERAL_TIMEOUT, +) -> list[Any]: + """Run the given coroutines concurrently and wait for them all to finish. + + Returns the results in the original order. If a single task raised an + exception it is re-raised. If multiple tasks raised, an + :class:`AsyncException` containing all of them is raised. If the timeout + elapses any still-pending tasks are cancelled and ``asyncio.TimeoutError`` + is raised. + """ + tasks = [asyncio.create_task(c) for c in iterable] + if not tasks: + return [] + + _, pending = await asyncio.wait(tasks, timeout=timeout) + if pending: + for task in pending: + task.cancel() + raise TimeoutError() + + results: list[Any] = [] + errors: list[BaseException] = [] + for task in tasks: + try: + results.append(task.result()) + except Exception as e: # noqa: BLE001 - propagated below + errors.append(e) + results.append(None) + + if errors: + if len(errors) == 1: + raise errors[0] + raise AsyncException(errors) + return results diff --git a/tests/conftest.py b/tests/conftest.py index 6099085..593b729 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -203,3 +203,74 @@ def mock_settings(): service_key="test-service-key", base_url="http://localhost:8000", ) + + +# --------------------------------------------------------------------------- +# Lightweight SQLite fixtures for trigger / scheduler-helper unit tests +# --------------------------------------------------------------------------- +# These bypass testcontainers/Docker for tests that only need a real session +# to write/read AutomationRun rows. The scheduler code uses ``using_sqlite()`` +# to gate Postgres-only features (``FOR UPDATE SKIP LOCKED``), so SQLite is a +# valid backend here. + + +@pytest.fixture +async def sqlite_engine(): + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest.fixture +async def sqlite_session_factory(sqlite_engine): + return async_sessionmaker( + sqlite_engine, class_=AsyncSession, expire_on_commit=False + ) + + +@pytest.fixture +async def sqlite_session(sqlite_session_factory): + async with sqlite_session_factory() as session: + yield session + + +@pytest.fixture +def patch_github_transport(monkeypatch): + """Inject an ``httpx.MockTransport`` into ``GithubTrigger._build_client``. + + Returns a callable: given a responder ``(httpx.Request) -> httpx.Response``, + installs it and returns the list of requests the trigger ends up issuing. + """ + import httpx + + from openhands.automation.schemas import GithubTrigger + + def install(responder): + seen: list[httpx.Request] = [] + + def capture(request: httpx.Request) -> httpx.Response: + seen.append(request) + return responder(request) + + transport = httpx.MockTransport(capture) + + def _patched(self: GithubTrigger) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url="https://api.github.com", + headers={ + "Accept": "application/vnd.github+json", + "X-GitHub-Api-Version": "2022-11-28", + "User-Agent": "openhands-automation", + "Authorization": ( + f"Bearer {self.github_access_token.get_secret_value()}" + ), + }, + transport=transport, + ) + + monkeypatch.setattr(GithubTrigger, "_build_client", _patched) + return seen + + return install diff --git a/tests/test_async_utils.py b/tests/test_async_utils.py new file mode 100644 index 0000000..9539ab9 --- /dev/null +++ b/tests/test_async_utils.py @@ -0,0 +1,64 @@ +"""Tests for openhands.automation.utils.async_utils.""" + +import asyncio + +import pytest + +from openhands.automation.utils.async_utils import AsyncException, wait_all + + +class TestWaitAll: + async def test_empty_iterable_returns_empty_list(self): + assert await wait_all([]) == [] + + async def test_results_in_original_order(self): + async def producer(value: int, delay: float) -> int: + await asyncio.sleep(delay) + return value + + # The slowest coroutine is first; result order must still match input. + results = await wait_all( + [producer(1, 0.03), producer(2, 0.0), producer(3, 0.01)] + ) + assert results == [1, 2, 3] + + async def test_runs_concurrently(self): + # Three 50ms sleeps must complete well under 150ms (serial baseline). + async def sleeper() -> int: + await asyncio.sleep(0.05) + return 1 + + loop = asyncio.get_running_loop() + start = loop.time() + results = await wait_all([sleeper(), sleeper(), sleeper()]) + elapsed = loop.time() - start + assert results == [1, 1, 1] + assert elapsed < 0.13 + + async def test_single_exception_propagates(self): + async def boom() -> None: + raise RuntimeError("kapow") + + async def ok() -> int: + return 42 + + with pytest.raises(RuntimeError, match="kapow"): + await wait_all([ok(), boom()]) + + async def test_multiple_exceptions_wrapped(self): + async def one() -> None: + raise ValueError("a") + + async def two() -> None: + raise ValueError("b") + + with pytest.raises(AsyncException) as exc: + await wait_all([one(), two()]) + assert len(exc.value.exceptions) == 2 + + async def test_timeout_cancels_pending(self): + async def slow() -> None: + await asyncio.sleep(1) + + with pytest.raises(asyncio.TimeoutError): + await wait_all([slow()], timeout=0.05) diff --git a/tests/test_scheduler_create_pending_runs.py b/tests/test_scheduler_create_pending_runs.py new file mode 100644 index 0000000..a2a3aca --- /dev/null +++ b/tests/test_scheduler_create_pending_runs.py @@ -0,0 +1,191 @@ +"""Tests for ``scheduler._create_pending_runs``. + +Exercises the scheduler's per-batch trigger dispatch loop against a real +(SQLite in-memory) session, validating the contract documented on +``_TriggerBase.create_pending_run``: return None ⇒ no run created, return +``AutomationRun`` ⇒ appended to the result list (and persisted by the +trigger itself). + +Only the GitHub HTTP transport is replaced; everything else is real code. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta + +import httpx +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.models import Automation, AutomationRun +from openhands.automation.scheduler import _create_pending_runs + + +TEST_USER_ID = uuid.UUID("12345678-1234-5678-1234-567812345678") +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") + + +async def _make( + session: AsyncSession, + *, + trigger: dict, + enabled: bool = True, + created_at: datetime | None = None, + last_triggered_at: datetime | None = None, +) -> Automation: + automation = Automation( + id=uuid.uuid4(), + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="Test", + trigger=trigger, + tarball_path="s3://bucket/code.tar.gz", + entrypoint="uv run main.py", + enabled=enabled, + deleted_at=None, + created_at=created_at or datetime(2026, 1, 1, tzinfo=UTC), + last_triggered_at=last_triggered_at, + ) + session.add(automation) + await session.flush() + return automation + + +class TestCreatePendingRuns: + async def test_empty_input(self, sqlite_session): + assert ( + await _create_pending_runs( + sqlite_session, [], datetime(2026, 1, 1, tzinfo=UTC) + ) + == [] + ) + + async def test_mixed_triggers_only_due_ones_create_runs( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + now = cutoff + timedelta(minutes=5) + + # Github API: events for `org/has-events`, nothing for `org/quiet`. + def responder(req: httpx.Request) -> httpx.Response: + if "has-events" in req.url.path: + return httpx.Response( + 200, + json=[ + { + "id": "1", + "type": "PushEvent", + "created_at": (cutoff + timedelta(minutes=1)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + } + ], + ) + return httpx.Response(200, json=[]) + + patch_github_transport(responder) + + cron_due = await _make( + sqlite_session, + trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"}, + created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC), + ) + cron_not_due = await _make( + sqlite_session, + trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"}, + created_at=datetime(2026, 3, 15, 12, 4, tzinfo=UTC), + ) + event_not_due = await _make( + sqlite_session, + trigger={ + "type": "event", + "source": "github", + "on": "pull_request.opened", + }, + ) + gh_due = await _make( + sqlite_session, + trigger={ + "type": "github", + "github_access_token": "ghp_xxx", + "repositories": ["org/has-events"], + }, + last_triggered_at=cutoff, + ) + gh_not_due = await _make( + sqlite_session, + trigger={ + "type": "github", + "github_access_token": "ghp_xxx", + "repositories": ["org/quiet"], + }, + last_triggered_at=cutoff, + ) + invalid = await _make(sqlite_session, trigger={"type": "totally-not-a-trigger"}) + + runs = await _create_pending_runs( + sqlite_session, + [cron_due, cron_not_due, event_not_due, gh_due, gh_not_due, invalid], + now, + ) + await sqlite_session.commit() + + run_automation_ids = {r.automation_id for r in runs} + assert run_automation_ids == {cron_due.id, gh_due.id} + + # GitHub run carries its event payload; cron run does not. + runs_by_aid = {r.automation_id: r for r in runs} + assert runs_by_aid[gh_due.id].event_payload is not None + assert runs_by_aid[gh_due.id].event_payload["source"] == "github_trigger" + assert len(runs_by_aid[gh_due.id].event_payload["events"]) == 1 + assert runs_by_aid[cron_due.id].event_payload is None + + # And the runs are durably persisted. + persisted_ids = { + r.automation_id + for r in (await sqlite_session.execute(select(AutomationRun))) + .scalars() + .all() + } + assert persisted_ids == {cron_due.id, gh_due.id} + + async def test_failing_trigger_does_not_block_others( + self, sqlite_session, patch_github_transport + ): + """A trigger raising must NOT stop other triggers from creating runs.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + now = cutoff + timedelta(minutes=5) + + # GitHub transport raises a connection error → the trigger's internal + # error handling logs and returns no events → no run, no exception. + # Then a SECOND github trigger explodes outright by raising from the + # transport responder again — both should be tolerated. + patch_github_transport( + lambda req: (_ for _ in ()).throw(httpx.ConnectError("no route")) + ) + + cron_due = await _make( + sqlite_session, + trigger={"type": "cron", "schedule": "0,30 * * * *", "timezone": "UTC"}, + created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC), + ) + gh_broken = await _make( + sqlite_session, + trigger={ + "type": "github", + "github_access_token": "ghp_xxx", + "repositories": ["org/exploding"], + }, + last_triggered_at=cutoff, + ) + + runs = await _create_pending_runs(sqlite_session, [cron_due, gh_broken], now) + await sqlite_session.commit() + + assert [r.automation_id for r in runs] == [cron_due.id] + + +if __name__ == "__main__": # pragma: no cover + pytest.main([__file__]) diff --git a/tests/test_triggers.py b/tests/test_triggers.py new file mode 100644 index 0000000..6c80603 --- /dev/null +++ b/tests/test_triggers.py @@ -0,0 +1,392 @@ +"""Tests for the trigger schemas and their ``create_pending_run`` methods. + +These tests exercise the real :class:`CronTrigger`/:class:`EventTrigger`/ +:class:`GithubTrigger` code paths against a real (SQLite in-memory) session, +so :func:`openhands.automation.utils.run.create_pending_run` runs against a +real database. Only the GitHub HTTP transport is replaced via +``httpx.MockTransport`` (see the ``patch_github_transport`` fixture); the +``GithubTrigger`` code — including header/auth construction, response +parsing, timestamp filtering, and parallel repo fan-out — is untouched. +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime, timedelta +from typing import Any + +import pytest +from pydantic import SecretStr, ValidationError +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from openhands.automation.models import Automation, AutomationRun +from openhands.automation.schemas import ( + EventTrigger, + GithubTrigger, + TriggerAdapter, +) + + +# --- helpers ---------------------------------------------------------------- + + +TEST_USER_ID = uuid.UUID("12345678-1234-5678-1234-567812345678") +TEST_ORG_ID = uuid.UUID("87654321-4321-8765-4321-876543218765") + + +async def _make_automation( + session: AsyncSession, + *, + trigger: dict[str, Any], + enabled: bool = True, + deleted_at: datetime | None = None, + created_at: datetime | None = None, + last_triggered_at: datetime | None = None, +) -> Automation: + automation = Automation( + id=uuid.uuid4(), + user_id=TEST_USER_ID, + org_id=TEST_ORG_ID, + name="Test", + trigger=trigger, + tarball_path="s3://bucket/code.tar.gz", + entrypoint="uv run main.py", + enabled=enabled, + deleted_at=deleted_at, + created_at=created_at or datetime(2026, 1, 1, tzinfo=UTC), + last_triggered_at=last_triggered_at, + ) + session.add(automation) + await session.flush() + return automation + + +def _gh_event( + event_id: int, created_at: datetime, event_type: str = "PushEvent" +) -> dict[str, Any]: + return { + "id": str(event_id), + "type": event_type, + "created_at": created_at.strftime("%Y-%m-%dT%H:%M:%SZ"), + } + + +# --------------------------------------------------------------------------- +# CronTrigger.create_pending_run +# --------------------------------------------------------------------------- + + +class TestCronTriggerCreatesRun: + async def test_creates_run_when_due(self, sqlite_session): + trigger_cfg = { + "type": "cron", + "schedule": "0,30 * * * *", + "timezone": "UTC", + } + # Every 30 min; created 10 min before the 10:30 fire window. + automation = await _make_automation( + sqlite_session, + trigger=trigger_cfg, + created_at=datetime(2026, 3, 15, 10, 25, tzinfo=UTC), + ) + trigger = TriggerAdapter.validate_python(trigger_cfg) + now = datetime(2026, 3, 15, 10, 35, tzinfo=UTC) + + run = await trigger.create_pending_run(sqlite_session, automation, now) + await sqlite_session.commit() + + assert run is not None + assert run.automation_id == automation.id + # Round-trip via DB to confirm the row was persisted. + from_db = ( + ( + await sqlite_session.execute( + select(AutomationRun).where(AutomationRun.id == run.id) + ) + ) + .scalars() + .first() + ) + assert from_db is not None + # Cron triggers don't attach an event payload. + assert from_db.event_payload is None + # The util bumped last_triggered_at on the automation. + assert automation.last_triggered_at is not None + + async def test_returns_none_when_not_due(self, sqlite_session): + trigger_cfg = { + "type": "cron", + "schedule": "0,30 * * * *", + "timezone": "UTC", + } + # Automation was JUST created — the most recent fire (10:30) is BEFORE + # creation (10:34), so it shouldn't fire yet. + automation = await _make_automation( + sqlite_session, + trigger=trigger_cfg, + created_at=datetime(2026, 3, 15, 10, 34, tzinfo=UTC), + ) + trigger = TriggerAdapter.validate_python(trigger_cfg) + now = datetime(2026, 3, 15, 10, 35, tzinfo=UTC) + + assert await trigger.create_pending_run(sqlite_session, automation, now) is None + + async def test_returns_none_when_disabled(self, sqlite_session): + trigger_cfg = {"type": "cron", "schedule": "* * * * *", "timezone": "UTC"} + automation = await _make_automation( + sqlite_session, trigger=trigger_cfg, enabled=False + ) + trigger = TriggerAdapter.validate_python(trigger_cfg) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + +# --------------------------------------------------------------------------- +# EventTrigger.create_pending_run +# --------------------------------------------------------------------------- + + +class TestEventTriggerNeverFires: + async def test_polling_never_creates_a_run(self, sqlite_session): + trigger_cfg = { + "type": "event", + "source": "github", + "on": "pull_request.opened", + } + automation = await _make_automation(sqlite_session, trigger=trigger_cfg) + trigger = EventTrigger.model_validate(trigger_cfg) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + +# --------------------------------------------------------------------------- +# GithubTrigger — validation +# --------------------------------------------------------------------------- + + +class TestGithubTriggerValidation: + def test_minimum_valid_config(self): + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["All-Hands-AI/OpenHands"], + ) + assert trigger.type == "github" + assert trigger.repositories == ["All-Hands-AI/OpenHands"] + # Secret never leaks via repr. + assert "ghp_xxx" not in repr(trigger) + assert trigger.github_access_token.get_secret_value() == "ghp_xxx" + + def test_invalid_repository_format_rejected(self): + with pytest.raises(ValidationError, match="Invalid repository"): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["not-a-repo"], + ) + + def test_empty_repositories_rejected(self): + with pytest.raises(ValidationError): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=[], + ) + + def test_discriminated_union_dispatches_to_github(self): + parsed = TriggerAdapter.validate_python( + { + "type": "github", + "github_access_token": "ghp_yyy", + "repositories": ["foo/bar"], + } + ) + assert isinstance(parsed, GithubTrigger) + + +# --------------------------------------------------------------------------- +# GithubTrigger.create_pending_run +# --------------------------------------------------------------------------- + + +class TestGithubTriggerCreatesRun: + async def test_fires_and_attaches_events_to_payload( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + fresh = _gh_event(2, cutoff + timedelta(minutes=5)) + seen = patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[fresh]) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + await sqlite_session.commit() + + assert run is not None + assert run.automation_id == automation.id + assert run.event_payload is not None + assert run.event_payload["source"] == "github_trigger" + assert len(run.event_payload["events"]) == 1 + event = run.event_payload["events"][0] + assert event["id"] == "2" + assert event["type"] == "PushEvent" + # The trigger tags each event with the source repo. + assert event["_repository"] == "foo/bar" + + # Sanity-check the outgoing HTTP request. + assert len(seen) == 1 + assert seen[0].url.path == "/repos/foo/bar/events" + assert seen[0].headers["Authorization"] == "Bearer ghp_xxx" + + async def test_returns_none_when_no_new_events( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + # All events are older than cutoff. + stale = _gh_event(1, cutoff - timedelta(minutes=10)) + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[stale]) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + assert await trigger.create_pending_run(sqlite_session, automation) is None + # No run was persisted. + runs = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert runs == [] + + async def test_event_type_filter_excludes_non_matching( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [_gh_event(2, cutoff + timedelta(minutes=5), event_type="IssuesEvent")] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_types=["PushEvent"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_uses_created_at_when_never_triggered( + self, sqlite_session, patch_github_transport + ): + created = datetime(2026, 3, 15, 9, 0, 0, tzinfo=UTC) + events = [_gh_event(3, created + timedelta(hours=1))] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + created_at=created, + last_triggered_at=None, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + assert len(run.event_payload["events"]) == 1 + + async def test_disabled_short_circuits_before_http( + self, sqlite_session, patch_github_transport + ): + seen = patch_github_transport( + lambda req: __import__("httpx").Response(200, json=[]) + ) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + enabled=False, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + assert seen == [] # Must short-circuit BEFORE hitting the network. + + async def test_non_200_response_is_not_due( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_github_transport( + lambda req: __import__("httpx").Response( + 403, json={"message": "rate limit"} + ) + ) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_collects_events_across_repos( + self, sqlite_session, patch_github_transport + ): + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + import httpx + + def responder(req: httpx.Request) -> httpx.Response: + if "foo/bar" in req.url.path: + return httpx.Response( + 200, + json=[_gh_event(1, cutoff + timedelta(minutes=1))], + ) + return httpx.Response( + 200, + json=[_gh_event(2, cutoff + timedelta(minutes=2))], + ) + + seen = patch_github_transport(responder) + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar", "foo/baz"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + # Both repos contributed an event. + events = run.event_payload["events"] + assert len(events) == 2 + repos = {ev["_repository"] for ev in events} + assert repos == {"foo/bar", "foo/baz"} + # Both repos were polled (parallel fan-out). + polled = {r.url.path for r in seen} + assert polled == {"/repos/foo/bar/events", "/repos/foo/baz/events"} From 71cd45c1650a94cdde5a6422cd7239e76c69cdfb Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 15 May 2026 20:25:56 +0000 Subject: [PATCH 2/3] refactor(GithubTrigger): replace event_types allow-list with JMESPath event_filter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Swaps the coarse event_types: list[str] allow-list for a single event_filter: str | None JMESPath expression. The expression is evaluated against each GitHub event JSON; an event is kept when the expression returns a truthy value, and the trigger fires only if at least one event matches. This is much more expressive — callers can filter on nested fields such as payload.action, payload.ref, actor.login, etc., without us having to expose each as a dedicated field. Examples: type == 'PushEvent' type == 'PullRequestEvent' && payload.action == 'opened' type == 'PushEvent' && payload.ref == 'refs/heads/main' Implementation notes: - The expression is compiled at field-validation time so syntax errors surface immediately as 422s when the trigger is POSTed, not at the next poll cycle. - The compiled parser is cached lazily on the instance (`_compiled_filter` on `__dict__`) so each poll only pays the parse cost once. - The persisted JSON keeps the raw string — round-trip through the AutomationsTrigger JSON column is unchanged. - JMESPath evaluation errors on a single event are logged and that event is dropped (treated as non-matching); they cannot kill the poll. Tests ----- - Removed the old `event_types` test. - Added five new tests: keep/drop based on filter result, nested `payload.action` matching, valid expression accepted, invalid expression rejected at validation, whitespace-only filter normalised to None. - All 29 trigger / async-util tests pass; broader 124-test unit sweep is green. Co-authored-by: openhands --- openhands/automation/schemas.py | 86 +++++++++++++++++++++----- tests/test_triggers.py | 104 +++++++++++++++++++++++++++++++- 2 files changed, 172 insertions(+), 18 deletions(-) diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index 4802ca2..46dd7e4 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -11,7 +11,9 @@ from zoneinfo import ZoneInfo import httpx +import jmespath from croniter import croniter +from jmespath.exceptions import JMESPathError from pydantic import ( BaseModel, ConfigDict, @@ -277,11 +279,23 @@ class GithubTrigger(_TriggerBase): ``run.event_payload`` so the run's entrypoint can react to them; otherwise it returns ``None``. - Optionally restrict the event types that count using ``event_types`` (e.g. - ``["PushEvent", "PullRequestEvent"]``); when omitted, any event type - triggers a fire. + Optionally narrow which events fire the trigger using + :attr:`event_filter` — a `JMESPath `_ expression + evaluated against each event. An event is kept when the expression + returns a truthy value. Examples:: + + type == 'PushEvent' + type == 'PullRequestEvent' && payload.action == 'opened' + type == 'PushEvent' && payload.ref == 'refs/heads/main' + + When ``event_filter`` is omitted, every event newer than the cutoff + fires the trigger. """ + # Tell Pydantic that the cached compiled JMESPath object (a non-Pydantic + # type) is acceptable on the model; we expose it via a property. + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + type: Literal["github"] = "github" github_access_token: SecretStr = Field( ..., @@ -299,11 +313,15 @@ class GithubTrigger(_TriggerBase): "(e.g. 'All-Hands-AI/OpenHands')." ), ) - event_types: list[str] | None = Field( + event_filter: str | None = Field( default=None, description=( - "Optional allow-list of GitHub event types (e.g. 'PushEvent'). " - "When unset, any event type counts as new activity." + "Optional JMESPath expression evaluated against each GitHub event " + "(see https://jmespath.org/). An event matches when the expression " + "returns a truthy value; non-matching events are dropped. The " + "trigger fires only if at least one event matches. Examples: " + "`type == 'PushEvent'`, " + "`type == 'PullRequestEvent' && payload.action == 'opened'`." ), ) @@ -318,13 +336,37 @@ def validate_repositories(cls, v: list[str]) -> list[str]: cleaned.append(repo) return cleaned - @field_validator("event_types") + @field_validator("event_filter") @classmethod - def validate_event_types(cls, v: list[str] | None) -> list[str] | None: + def validate_event_filter(cls, v: str | None) -> str | None: + """Parse the JMESPath at validation time so bad syntax fails fast. + + The compiled expression itself is *not* stored on the field (so the + JSON round-trip is clean) — it's cached lazily on first use via + :attr:`_compiled_event_filter`. + """ if v is None: return v - cleaned = [t.strip() for t in v if t and t.strip()] - return cleaned or None + expr = v.strip() + if not expr: + return None + try: + jmespath.compile(expr) + except JMESPathError as e: + raise ValueError(f"Invalid JMESPath expression: {e}") from e + return expr + + @property + def _compiled_event_filter(self) -> jmespath.parser.ParsedResult | None: + """Compile and cache the JMESPath expression on first use.""" + if self.event_filter is None: + return None + cached = self.__dict__.get("_compiled_filter") + if cached is None: + # Already validated in ``validate_event_filter`` — won't raise. + cached = jmespath.compile(self.event_filter) + self.__dict__["_compiled_filter"] = cached + return cached @field_serializer("github_access_token", when_used="always") def _serialize_token(self, v: SecretStr) -> str: @@ -364,9 +406,12 @@ async def _fetch_new_events( ) -> list[dict[str, Any]]: """Return all matching events for ``repo`` newer than ``cutoff``. - Honours :attr:`event_types`. Errors (HTTP/JSON/non-200) are logged - and treated as "no new events" so a single bad repo doesn't take - down the whole trigger. + Applies :attr:`event_filter` (a JMESPath expression) if set: each + event is kept only when the expression evaluates to a truthy value. + Errors (HTTP/JSON/non-200) are logged and treated as "no new events" + so a single bad repo doesn't take down the whole trigger. A + JMESPath evaluation error on an individual event is logged and the + event is dropped (treated as non-matching). """ try: resp = await client.get( @@ -393,13 +438,22 @@ async def _fetch_new_events( if not isinstance(events, list): return [] - allowed: set[str] | None = set(self.event_types) if self.event_types else None + compiled = self._compiled_event_filter new_events: list[dict[str, Any]] = [] for ev in events: if not isinstance(ev, dict): continue - if allowed is not None and ev.get("type") not in allowed: - continue + if compiled is not None: + try: + if not compiled.search(ev): + continue + except JMESPathError as e: + logger.warning( + "JMESPath evaluation failed for event in %s: %s", + repo, + e, + ) + continue created_raw = ev.get("created_at") if not isinstance(created_raw, str): continue diff --git a/tests/test_triggers.py b/tests/test_triggers.py index 6c80603..d7512e6 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -189,6 +189,34 @@ def test_empty_repositories_rejected(self): repositories=[], ) + def test_valid_jmespath_event_filter_is_accepted(self): + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="type == 'PushEvent' && payload.ref == 'refs/heads/main'", + ) + assert ( + trigger.event_filter + == "type == 'PushEvent' && payload.ref == 'refs/heads/main'" + ) + + def test_invalid_jmespath_event_filter_rejected(self): + with pytest.raises(ValidationError, match="Invalid JMESPath expression"): + GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="this is not valid jmespath @@@", + ) + + def test_empty_event_filter_becomes_none(self): + # Whitespace-only filter shouldn't error and shouldn't be retained. + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter=" ", + ) + assert trigger.event_filter is None + def test_discriminated_union_dispatches_to_github(self): parsed = TriggerAdapter.validate_python( { @@ -269,9 +297,10 @@ async def test_returns_none_when_no_new_events( runs = (await sqlite_session.execute(select(AutomationRun))).scalars().all() assert runs == [] - async def test_event_type_filter_excludes_non_matching( + async def test_jmespath_filter_excludes_non_matching( self, sqlite_session, patch_github_transport ): + """``event_filter`` drops events whose JMESPath expression is falsy.""" cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) events = [_gh_event(2, cutoff + timedelta(minutes=5), event_type="IssuesEvent")] patch_github_transport( @@ -281,7 +310,7 @@ async def test_event_type_filter_excludes_non_matching( trigger = GithubTrigger( github_access_token=SecretStr("ghp_xxx"), repositories=["foo/bar"], - event_types=["PushEvent"], + event_filter="type == 'PushEvent'", ) automation = await _make_automation( sqlite_session, @@ -290,6 +319,77 @@ async def test_event_type_filter_excludes_non_matching( ) assert await trigger.create_pending_run(sqlite_session, automation) is None + async def test_jmespath_filter_keeps_matching( + self, sqlite_session, patch_github_transport + ): + """A truthy JMESPath result keeps the event and fires the trigger.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [ + _gh_event(7, cutoff + timedelta(minutes=1), event_type="IssuesEvent"), + _gh_event(8, cutoff + timedelta(minutes=2), event_type="PushEvent"), + ] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter="type == 'PushEvent'", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + # Only the PushEvent survived the filter. + kept = run.event_payload["events"] + assert [e["id"] for e in kept] == ["8"] + + async def test_jmespath_filter_supports_nested_payload( + self, sqlite_session, patch_github_transport + ): + """JMESPath can match against the nested ``payload`` object.""" + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + events = [ + { + "id": "9", + "type": "PullRequestEvent", + "created_at": (cutoff + timedelta(minutes=1)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "payload": {"action": "closed"}, + }, + { + "id": "10", + "type": "PullRequestEvent", + "created_at": (cutoff + timedelta(minutes=2)).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + "payload": {"action": "opened"}, + }, + ] + patch_github_transport( + lambda req: __import__("httpx").Response(200, json=events) + ) + + trigger = GithubTrigger( + github_access_token=SecretStr("ghp_xxx"), + repositories=["foo/bar"], + event_filter=("type == 'PullRequestEvent' && payload.action == 'opened'"), + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["events"] + assert [e["id"] for e in kept] == ["10"] + async def test_uses_created_at_when_never_triggered( self, sqlite_session, patch_github_transport ): From e27e1202dac91bbe8956495b4de1bc3547d70dab Mon Sep 17 00:00:00 2001 From: openhands Date: Fri, 15 May 2026 20:37:56 +0000 Subject: [PATCH 3/3] feat: add SlackTrigger and factor shared polling logic into _PollingTriggerBase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new poll-based trigger that fires when configured Slack channels receive new messages. Like GithubTrigger, it polls the upstream service in parallel and exposes a JMESPath event_filter. Refactor -------- Introduces _PollingTriggerBase, a private subclass of _TriggerBase shared by GithubTrigger and SlackTrigger. The base class owns the bits that are identical across both: - The event_filter field + JMESPath compile-at-validation logic. - The cached _compiled_event_filter property. - _item_matches_filter() — runtime JMESPath evaluation with per-item error handling. - create_pending_run() — enabled/deleted/cutoff guards, async client lifecycle, and persisting the run with the right event_payload envelope. Concrete subclasses now only implement: - _PAYLOAD_SOURCE / _PAYLOAD_KEY (ClassVars for the payload envelope). - _build_client() — auth headers, base URL. - _fetch_all_new(client, cutoff) — typically a wait_all() fan-out. This drops GithubTrigger from ~250 lines to ~110 with no behaviour change (all 24 existing trigger tests still pass). SlackTrigger ------------ - type='slack', slack_token (SecretStr, xoxp-/xoxb-), channels (list of Slack channel IDs, validated against ^[A-Z][A-Z0-9]+$). - Polls /conversations.history with oldest= as the high-water mark, derived from automation.last_triggered_at (or created_at on first poll — no backfill). - Honours Slack's quirky error model: HTTP 200 always for app errors, so we check body['ok'] and log body['error']. HTTP 429 logs Retry-After and treats the cycle as 'no messages'. - Sorts newest-first responses into chronological order before delivery. - Does defense-in-depth client-side cutoff comparison (ts_unix > cutoff) even though Slack's 'oldest' is server-side exclusive. - Tags each message with '_channel' so downstream code knows the origin. - Payload: {source: 'slack_trigger', messages: [...]}. Reference: https://github.com/tofarr/polling-experiment/blob/main/scripts/poll_slack_messages.py Tests (16 new, 45 total in test_triggers.py) -------------------------------------------- - TestSlackTriggerValidation x7: channel id regex (accept/reject lower case, accept various prefixes, reject channel names like '#general', reject empty list), invalid JMESPath rejected, SecretStr round-trip, discriminator dispatch. - TestSlackTriggerCreatesRun x9: happy path with payload + 'oldest' assertion, no-new-messages, ok=false treated as no messages, 429 rate limit treated as no messages, cutoff-edge message excluded, JMESPath keep/exclude, multi-channel parallel fan-out, disabled short-circuit (zero HTTP calls). - conftest.py: new patch_slack_transport fixture mirroring patch_github_transport. All 140 unit tests pass (45 trigger, 3 scheduler create-pending-run, 6 async_utils, plus broader scheduler/event/config sweeps). ruff format / ruff check clean. Co-authored-by: openhands --- openhands/automation/schemas.py | 479 ++++++++++++++++++++++++-------- tests/conftest.py | 37 +++ tests/test_triggers.py | 358 +++++++++++++++++++++++- 3 files changed, 758 insertions(+), 116 deletions(-) diff --git a/openhands/automation/schemas.py b/openhands/automation/schemas.py index 46dd7e4..48b0a89 100644 --- a/openhands/automation/schemas.py +++ b/openhands/automation/schemas.py @@ -7,7 +7,7 @@ import uuid from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal from zoneinfo import ZoneInfo import httpx @@ -39,6 +39,10 @@ _GITHUB_REPO_RE = re.compile(r"^[A-Za-z0-9_.\-]+/[A-Za-z0-9_.\-]+$") +# Slack channel IDs are uppercase alphanumeric (e.g. 'C0123ABC', 'D…', 'G…'). +# Channel *names* aren't supported here because they're not stable across renames. +_SLACK_CHANNEL_ID_RE = re.compile(r"^[A-Z][A-Z0-9]+$") + # Allowed URI schemes for tarball_path (includes internal upload scheme) _TARBALL_SCHEME_RE = re.compile(r"^(s3|gs|https?|oh-internal)://") @@ -265,7 +269,154 @@ async def create_pending_run( return None -class GithubTrigger(_TriggerBase): +class _PollingTriggerBase(_TriggerBase): + """Base for triggers that poll external services for new items. + + Concrete subclasses (e.g. :class:`GithubTrigger`, :class:`SlackTrigger`) + own three pieces of behaviour: + + 1. **Authentication / client construction** — :meth:`_build_client`. + 2. **Fan-out fetching** — :meth:`_fetch_all_new` returns the flat list of + items (already filtered against the cutoff and :attr:`event_filter`) + collected from every configured resource, typically running per-resource + calls concurrently via :func:`wait_all`. + 3. **Payload labelling** — :attr:`_PAYLOAD_SOURCE` and :attr:`_PAYLOAD_KEY` + (e.g. ``"github_trigger"`` / ``"events"``). + + The base class itself handles: + + - The ``enabled`` / ``deleted_at`` short-circuit. + - Computing the cutoff (``last_triggered_at`` or ``created_at``). + - Calling :func:`openhands.automation.utils.run.create_pending_run` once + items are gathered and attaching them to ``run.event_payload``. + - Validating and compiling the optional JMESPath ``event_filter``. + """ + + # Pydantic config. ``arbitrary_types_allowed`` lets us cache a compiled + # JMESPath parser object on ``self.__dict__`` without Pydantic objecting. + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + + # Subclass-supplied metadata for the run's ``event_payload`` envelope. + _PAYLOAD_SOURCE: ClassVar[str] = "" + _PAYLOAD_KEY: ClassVar[str] = "items" + + event_filter: str | None = Field( + default=None, + description=( + "Optional JMESPath expression evaluated against each polled item " + "(see https://jmespath.org/). An item matches when the expression " + "returns a truthy value; non-matching items are dropped. The " + "trigger fires only if at least one item matches." + ), + ) + + @field_validator("event_filter") + @classmethod + def validate_event_filter(cls, v: str | None) -> str | None: + """Parse the JMESPath at validation time so bad syntax fails fast. + + The compiled expression itself is *not* stored on the field (so the + JSON round-trip stays clean) — it's cached lazily on first use via + :attr:`_compiled_event_filter`. + """ + if v is None: + return v + expr = v.strip() + if not expr: + return None + try: + jmespath.compile(expr) + except JMESPathError as e: + raise ValueError(f"Invalid JMESPath expression: {e}") from e + return expr + + @property + def _compiled_event_filter(self) -> jmespath.parser.ParsedResult | None: + """Compile and cache the JMESPath expression on first use.""" + if self.event_filter is None: + return None + cached = self.__dict__.get("_compiled_filter") + if cached is None: + # Already validated in ``validate_event_filter`` — won't raise. + cached = jmespath.compile(self.event_filter) + self.__dict__["_compiled_filter"] = cached + return cached + + def _item_matches_filter(self, item: dict[str, Any], context: str) -> bool: + """Return True if ``item`` passes ``event_filter`` (or no filter). + + ``context`` is a human-readable label (repo / channel id) used only + for logging if the JMESPath evaluator throws. + """ + compiled = self._compiled_event_filter + if compiled is None: + return True + try: + return bool(compiled.search(item)) + except JMESPathError as e: + logger.warning("JMESPath evaluation failed for item in %s: %s", context, e) + return False + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated HTTP client for the upstream service.""" + raise NotImplementedError + + async def _fetch_all_new( + self, + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return all items newer than ``cutoff`` from every configured resource. + + Implementations are expected to fan out across their resources + concurrently (typically via + :func:`openhands.automation.utils.async_utils.wait_all`) and apply + :meth:`_item_matches_filter` per item. + """ + raise NotImplementedError + + async def create_pending_run( + self, + session: AsyncSession, + automation: Automation, + now: datetime | None = None, # noqa: ARG002 + ) -> AutomationRun | None: + """Fire if any configured resource yields a matching new item. + + On fire, attaches the gathered items to ``run.event_payload`` as:: + + {"source": , + : [...]} + """ + # Deferred import avoids a circular dependency at module load. + from openhands.automation.utils.run import ( + create_pending_run as _create_run_util, + ) + + if not automation.enabled or automation.deleted_at is not None: + return None + + cutoff = automation.last_triggered_at or automation.created_at + if cutoff is None: + return None + if cutoff.tzinfo is None: + cutoff = cutoff.replace(tzinfo=ZoneInfo("UTC")) + + async with self._build_client() as client: + items = await self._fetch_all_new(client, cutoff) + + if not items: + return None + + run = await _create_run_util(session, automation) + run.event_payload = { + "source": self._PAYLOAD_SOURCE, + self._PAYLOAD_KEY: items, + } + return run + + +class GithubTrigger(_PollingTriggerBase): """Poll-based trigger that fires when new events appear on GitHub repos. On each scheduler poll, the trigger queries the @@ -292,9 +443,8 @@ class GithubTrigger(_TriggerBase): fires the trigger. """ - # Tell Pydantic that the cached compiled JMESPath object (a non-Pydantic - # type) is acceptable on the model; we expose it via a property. - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + _PAYLOAD_SOURCE: ClassVar[str] = "github_trigger" + _PAYLOAD_KEY: ClassVar[str] = "events" type: Literal["github"] = "github" github_access_token: SecretStr = Field( @@ -313,17 +463,6 @@ class GithubTrigger(_TriggerBase): "(e.g. 'All-Hands-AI/OpenHands')." ), ) - event_filter: str | None = Field( - default=None, - description=( - "Optional JMESPath expression evaluated against each GitHub event " - "(see https://jmespath.org/). An event matches when the expression " - "returns a truthy value; non-matching events are dropped. The " - "trigger fires only if at least one event matches. Examples: " - "`type == 'PushEvent'`, " - "`type == 'PullRequestEvent' && payload.action == 'opened'`." - ), - ) @field_validator("repositories") @classmethod @@ -336,42 +475,9 @@ def validate_repositories(cls, v: list[str]) -> list[str]: cleaned.append(repo) return cleaned - @field_validator("event_filter") - @classmethod - def validate_event_filter(cls, v: str | None) -> str | None: - """Parse the JMESPath at validation time so bad syntax fails fast. - - The compiled expression itself is *not* stored on the field (so the - JSON round-trip is clean) — it's cached lazily on first use via - :attr:`_compiled_event_filter`. - """ - if v is None: - return v - expr = v.strip() - if not expr: - return None - try: - jmespath.compile(expr) - except JMESPathError as e: - raise ValueError(f"Invalid JMESPath expression: {e}") from e - return expr - - @property - def _compiled_event_filter(self) -> jmespath.parser.ParsedResult | None: - """Compile and cache the JMESPath expression on first use.""" - if self.event_filter is None: - return None - cached = self.__dict__.get("_compiled_filter") - if cached is None: - # Already validated in ``validate_event_filter`` — won't raise. - cached = jmespath.compile(self.event_filter) - self.__dict__["_compiled_filter"] = cached - return cached - @field_serializer("github_access_token", when_used="always") def _serialize_token(self, v: SecretStr) -> str: - """Emit the raw secret so the trigger can round-trip through the - JSON column. + """Emit the raw secret so the trigger round-trips through the JSON column. The token must be stored in plain text because the scheduler needs to read it back later to authenticate against GitHub. ``SecretStr`` @@ -406,12 +512,8 @@ async def _fetch_new_events( ) -> list[dict[str, Any]]: """Return all matching events for ``repo`` newer than ``cutoff``. - Applies :attr:`event_filter` (a JMESPath expression) if set: each - event is kept only when the expression evaluates to a truthy value. Errors (HTTP/JSON/non-200) are logged and treated as "no new events" - so a single bad repo doesn't take down the whole trigger. A - JMESPath evaluation error on an individual event is logged and the - event is dropped (treated as non-matching). + so a single bad repo doesn't take down the whole trigger. """ try: resp = await client.get( @@ -438,22 +540,12 @@ async def _fetch_new_events( if not isinstance(events, list): return [] - compiled = self._compiled_event_filter new_events: list[dict[str, Any]] = [] for ev in events: if not isinstance(ev, dict): continue - if compiled is not None: - try: - if not compiled.search(ev): - continue - except JMESPathError as e: - logger.warning( - "JMESPath evaluation failed for event in %s: %s", - repo, - e, - ) - continue + if not self._item_matches_filter(ev, repo): + continue created_raw = ev.get("created_at") if not isinstance(created_raw, str): continue @@ -464,73 +556,239 @@ async def _fetch_new_events( if created_at.tzinfo is None: created_at = created_at.replace(tzinfo=ZoneInfo("UTC")) if created_at > cutoff: - # Tag with repo so downstream code knows where it came from. + # Tag with the source repo so downstream code knows the origin. tagged = dict(ev) tagged.setdefault("_repository", repo) new_events.append(tagged) return new_events - async def create_pending_run( + async def _fetch_all_new( self, - session: AsyncSession, - automation: Automation, - now: datetime | None = None, # noqa: ARG002 - ) -> AutomationRun | None: - """Fire if any configured repo has matching new events. + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + from openhands.automation.utils.async_utils import wait_all + + per_repo: list[list[dict[str, Any]]] = await wait_all( + [ + self._fetch_new_events(client, repo, cutoff) + for repo in self.repositories + ], + timeout=None, + ) + return [ev for batch in per_repo for ev in batch] + - On fire, attaches the events that caused the fire to - ``run.event_payload`` as:: +class SlackTrigger(_PollingTriggerBase): + """Poll-based trigger that fires when new messages appear in Slack channels. - { - "source": "github_trigger", - "events": [, ...], - } + On each scheduler poll, the trigger calls + ``https://slack.com/api/conversations.history`` for each configured + channel **concurrently**, using ``oldest=`` as the + high-water mark (derived from the automation's last fire time, or + ``created_at`` for the very first poll). Messages with a Slack ``ts`` + strictly greater than the cutoff are kept. + + Optionally narrow which messages fire the trigger using + :attr:`event_filter` — a `JMESPath `_ expression + evaluated against each message dict. A message is kept when the + expression returns a truthy value. Examples:: + + subtype == null # regular user messages only + type == 'message' && user == 'U0123ABC' + contains(text, '@here') + + Required Slack OAuth scopes depend on the channel types polled: + ``channels:history``, ``groups:history``, ``im:history``, ``mpim:history``. + Bot tokens (``xoxb-…``) work but the bot must be a member of each + channel; user tokens (``xoxp-…``) generally don't need membership. + + Notes: + + - The Slack Web API returns HTTP 200 even for application errors; we + detect those via ``body["ok"]`` and log ``body["error"]``. + - On HTTP 429 (rate limit) we log the ``Retry-After`` value and treat + the channel as having no new messages this cycle. The next scheduler + tick will retry. + - Only the first page (up to 200 messages) is fetched per channel per + poll. Channels exceeding that between polls will skip the gap — + tighten ``scheduler_interval`` or shorten poll cycles to compensate. + """ + + _PAYLOAD_SOURCE: ClassVar[str] = "slack_trigger" + _PAYLOAD_KEY: ClassVar[str] = "messages" + _SLACK_MESSAGE_PAGE_LIMIT: ClassVar[int] = 200 + + type: Literal["slack"] = "slack" + slack_token: SecretStr = Field( + ..., + description=( + "Slack token used to authenticate against the Web API. Either a " + "user token ('xoxp-…') or a bot token ('xoxb-…'). The token " + "needs the appropriate '*:history' OAuth scopes for the " + "channel types being polled." + ), + ) + channels: list[str] = Field( + ..., + min_length=1, + description=( + "Slack channel IDs to poll (e.g. 'C0123ABC'). IDs — not names — " + "are required because names are not stable across renames. Find a " + "channel's ID in its 'About' panel inside the Slack client." + ), + ) + + @field_validator("channels") + @classmethod + def validate_channels(cls, v: list[str]) -> list[str]: + cleaned: list[str] = [] + for c in v: + c = c.strip() + if not _SLACK_CHANNEL_ID_RE.match(c): + raise ValueError( + f"Invalid Slack channel id {c!r}: expected uppercase " + "alphanumeric like 'C0123ABC' (channel IDs, not names)." + ) + cleaned.append(c) + return cleaned + + @field_serializer("slack_token", when_used="always") + def _serialize_token(self, v: SecretStr) -> str: + """Emit the raw secret so the trigger round-trips through the JSON column. + + See the corresponding note on :class:`GithubTrigger` — the token is + persisted in plain text because the scheduler must reuse it on every + poll. ``SecretStr`` still protects against accidental logging at the + application layer. """ - # Deferred imports avoid circular dependencies at module load time. - from openhands.automation.utils.async_utils import wait_all - from openhands.automation.utils.run import ( - create_pending_run as _create_run_util, + return v.get_secret_value() + + def _build_client(self) -> httpx.AsyncClient: + """Construct an authenticated Slack Web API client.""" + return httpx.AsyncClient( + base_url="https://slack.com/api", + headers={ + "Accept": "application/json", + "User-Agent": "openhands-automation", + "Authorization": f"Bearer {self.slack_token.get_secret_value()}", + }, + timeout=30.0, ) - if not automation.enabled or automation.deleted_at is not None: - return None + async def _fetch_new_messages( + self, + client: httpx.AsyncClient, + channel: str, + cutoff: datetime, + ) -> list[dict[str, Any]]: + """Return matching messages for ``channel`` newer than ``cutoff``. - cutoff = automation.last_triggered_at or automation.created_at - if cutoff is None: - return None - if cutoff.tzinfo is None: - cutoff = cutoff.replace(tzinfo=ZoneInfo("UTC")) + Uses Slack's ``oldest`` parameter (exclusive) for server-side + filtering. Errors — HTTP, JSON, rate-limit, or ``ok=false`` — + are logged and treated as "no new messages" so a single bad + channel cannot take down the whole trigger. + """ + oldest = f"{cutoff.timestamp():.6f}" + try: + resp = await client.get( + "/conversations.history", + params={ + "channel": channel, + "oldest": oldest, + "limit": self._SLACK_MESSAGE_PAGE_LIMIT, + }, + ) + except httpx.HTTPError as e: + logger.warning("Slack poll failed for %s: %s", channel, e) + return [] - async with self._build_client() as client: - per_repo: list[list[dict[str, Any]]] = await wait_all( - [ - self._fetch_new_events(client, repo, cutoff) - for repo in self.repositories - ], - timeout=None, + if resp.status_code == 429: + retry_after = resp.headers.get("Retry-After", "unknown") + logger.warning( + "Slack poll for %s rate-limited (Retry-After=%s)", + channel, + retry_after, ) + return [] + if resp.status_code != 200: + logger.warning( + "Slack poll for %s returned status %s", + channel, + resp.status_code, + ) + return [] - all_events: list[dict[str, Any]] = [ev for batch in per_repo for ev in batch] - if not all_events: - return None + try: + body = resp.json() + except ValueError: + logger.warning("Slack poll for %s returned non-JSON body", channel) + return [] + if not isinstance(body, dict): + return [] + if not body.get("ok"): + logger.warning( + "Slack poll for %s returned error: %s", + channel, + body.get("error", "unknown"), + ) + return [] - run = await _create_run_util(session, automation) - run.event_payload = {"source": "github_trigger", "events": all_events} - return run + messages = body.get("messages") or [] + if not isinstance(messages, list): + return [] + + # Slack returns newest-first; deliver in natural chronological order. + messages_sorted = sorted(messages, key=lambda m: float(m.get("ts", "0") or "0")) + + new_messages: list[dict[str, Any]] = [] + for msg in messages_sorted: + if not isinstance(msg, dict): + continue + ts_raw = msg.get("ts") + if not isinstance(ts_raw, str): + continue + try: + ts_unix = float(ts_raw) + except ValueError: + continue + # Defense in depth — `oldest` is exclusive but check anyway. + if ts_unix <= cutoff.timestamp(): + continue + if not self._item_matches_filter(msg, channel): + continue + tagged = dict(msg) + tagged.setdefault("_channel", channel) + new_messages.append(tagged) + return new_messages + + async def _fetch_all_new( + self, + client: httpx.AsyncClient, + cutoff: datetime, + ) -> list[dict[str, Any]]: + from openhands.automation.utils.async_utils import wait_all + + per_channel: list[list[dict[str, Any]]] = await wait_all( + [self._fetch_new_messages(client, c, cutoff) for c in self.channels], + timeout=None, + ) + return [m for batch in per_channel for m in batch] def _get_trigger_discriminator(v: dict | BaseModel) -> str: """Discriminator function for Pydantic's discriminated union. Returns the trigger type string, which Pydantic uses to select the - correct model (CronTrigger, EventTrigger, or GithubTrigger) from the union. + correct model (CronTrigger, EventTrigger, GithubTrigger, or SlackTrigger) + from the union. Why sentinel instead of raising ValueError: Pydantic discriminator functions must return a string - they cannot raise exceptions. By returning an invalid sentinel value, Pydantic generates a proper ValidationError with context like: "Input tag '__missing_trigger_type__' found using 'type' does not - match any of the expected tags: 'cron', 'event', 'github'" + match any of the expected tags: 'cron', 'event', 'github', 'slack'" This produces a user-friendly 422 response via FastAPI. """ if isinstance(v, dict): @@ -545,15 +803,16 @@ def _get_trigger_discriminator(v: dict | BaseModel) -> str: Trigger = Annotated[ Annotated[CronTrigger, Tag("cron")] | Annotated[EventTrigger, Tag("event")] - | Annotated[GithubTrigger, Tag("github")], + | Annotated[GithubTrigger, Tag("github")] + | Annotated[SlackTrigger, Tag("slack")], Discriminator(_get_trigger_discriminator), ] # Reusable adapter for parsing trigger dicts (e.g. ``automation.trigger`` JSON) # into the correct ``_TriggerBase`` subclass. -TriggerAdapter: TypeAdapter[CronTrigger | EventTrigger | GithubTrigger] = TypeAdapter( - Trigger -) +TriggerAdapter: TypeAdapter[ + CronTrigger | EventTrigger | GithubTrigger | SlackTrigger +] = TypeAdapter(Trigger) class RunStatus(StrEnum): diff --git a/tests/conftest.py b/tests/conftest.py index 593b729..d3c0fec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -274,3 +274,40 @@ def _patched(self: GithubTrigger) -> httpx.AsyncClient: return seen return install + + +@pytest.fixture +def patch_slack_transport(monkeypatch): + """Inject an ``httpx.MockTransport`` into ``SlackTrigger._build_client``. + + Returns a callable: given a responder ``(httpx.Request) -> httpx.Response``, + installs it and returns the list of requests the trigger ends up issuing. + """ + import httpx + + from openhands.automation.schemas import SlackTrigger + + def install(responder): + seen: list[httpx.Request] = [] + + def capture(request: httpx.Request) -> httpx.Response: + seen.append(request) + return responder(request) + + transport = httpx.MockTransport(capture) + + def _patched(self: SlackTrigger) -> httpx.AsyncClient: + return httpx.AsyncClient( + base_url="https://slack.com/api", + headers={ + "Accept": "application/json", + "User-Agent": "openhands-automation", + "Authorization": f"Bearer {self.slack_token.get_secret_value()}", + }, + transport=transport, + ) + + monkeypatch.setattr(SlackTrigger, "_build_client", _patched) + return seen + + return install diff --git a/tests/test_triggers.py b/tests/test_triggers.py index d7512e6..4a7ce37 100644 --- a/tests/test_triggers.py +++ b/tests/test_triggers.py @@ -1,12 +1,14 @@ """Tests for the trigger schemas and their ``create_pending_run`` methods. These tests exercise the real :class:`CronTrigger`/:class:`EventTrigger`/ -:class:`GithubTrigger` code paths against a real (SQLite in-memory) session, -so :func:`openhands.automation.utils.run.create_pending_run` runs against a -real database. Only the GitHub HTTP transport is replaced via -``httpx.MockTransport`` (see the ``patch_github_transport`` fixture); the -``GithubTrigger`` code — including header/auth construction, response -parsing, timestamp filtering, and parallel repo fan-out — is untouched. +:class:`GithubTrigger`/:class:`SlackTrigger` code paths against a real +(SQLite in-memory) session, so +:func:`openhands.automation.utils.run.create_pending_run` runs against a +real database. Only the GitHub/Slack HTTP transports are replaced via +``httpx.MockTransport`` (see the ``patch_github_transport`` / +``patch_slack_transport`` fixtures); the trigger code — including +header/auth construction, response parsing, timestamp filtering, and +parallel resource fan-out — is untouched. """ from __future__ import annotations @@ -24,6 +26,7 @@ from openhands.automation.schemas import ( EventTrigger, GithubTrigger, + SlackTrigger, TriggerAdapter, ) @@ -490,3 +493,346 @@ def responder(req: httpx.Request) -> httpx.Response: # Both repos were polled (parallel fan-out). polled = {r.url.path for r in seen} assert polled == {"/repos/foo/bar/events", "/repos/foo/baz/events"} + + +# --------------------------------------------------------------------------- +# SlackTrigger +# --------------------------------------------------------------------------- + + +def _slack_message( + msg_id: str, + ts_dt: datetime, + *, + text: str = "hello", + user: str = "U0123ABC", + msg_type: str = "message", + subtype: str | None = None, +) -> dict[str, Any]: + """Build a fake Slack message dict. + + Slack ``ts`` is a string of unix seconds with microsecond precision. + """ + msg: dict[str, Any] = { + "type": msg_type, + "user": user, + "text": text, + "ts": f"{ts_dt.timestamp():.6f}", + "client_msg_id": msg_id, + } + if subtype is not None: + msg["subtype"] = subtype + return msg + + +def _slack_ok(messages: list[dict[str, Any]]) -> dict[str, Any]: + return {"ok": True, "messages": messages, "has_more": False} + + +class TestSlackTriggerValidation: + def test_valid_channel_ids_are_accepted(self): + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC", "D0SOMEDM"], + ) + assert trigger.channels == ["C0123ABC", "D0SOMEDM"] + + def test_lowercase_channel_id_rejected(self): + with pytest.raises(ValidationError, match="Invalid Slack channel id"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["c0123abc"], + ) + + def test_channel_name_rejected(self): + # '#general' isn't a stable id; ensure we reject names early. + with pytest.raises(ValidationError, match="Invalid Slack channel id"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["#general"], + ) + + def test_empty_channels_rejected(self): + with pytest.raises(ValidationError): + SlackTrigger(slack_token=SecretStr("xoxb-xxx"), channels=[]) + + def test_invalid_jmespath_event_filter_rejected(self): + # The JMESPath validation lives on the shared polling base, so this + # also exercises that SlackTrigger inherits it correctly. + with pytest.raises(ValidationError, match="Invalid JMESPath expression"): + SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="not valid @@@", + ) + + def test_token_round_trips_through_model_dump(self): + """``model_dump`` must emit the raw string so the trigger can be + persisted in the AutomationsTrigger JSON column and rehydrated.""" + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-supersecret"), + channels=["C0123ABC"], + ) + dumped = trigger.model_dump(mode="python") + assert dumped["slack_token"] == "xoxb-supersecret" + # Round-trip through the discriminated union too. + roundtripped = TriggerAdapter.validate_python(dumped) + assert isinstance(roundtripped, SlackTrigger) + assert roundtripped.slack_token.get_secret_value() == "xoxb-supersecret" + + def test_discriminated_union_dispatches_to_slack(self): + parsed = TriggerAdapter.validate_python( + { + "type": "slack", + "slack_token": "xoxb-xxx", + "channels": ["C0123ABC"], + } + ) + assert isinstance(parsed, SlackTrigger) + + +class TestSlackTriggerCreatesRun: + async def test_new_message_creates_run_with_payload( + self, sqlite_session, patch_slack_transport + ): + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msg = _slack_message("m1", cutoff + timedelta(seconds=30)) + seen = patch_slack_transport( + lambda req: httpx.Response(200, json=_slack_ok([msg])) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + assert run.event_payload["source"] == "slack_trigger" + kept = run.event_payload["messages"] + assert [m["client_msg_id"] for m in kept] == ["m1"] + # Channel was tagged onto the message. + assert kept[0]["_channel"] == "C0123ABC" + # The Slack `oldest` param uses unix seconds derived from the cutoff. + assert len(seen) == 1 + assert seen[0].url.path == "/api/conversations.history" + assert seen[0].url.params["channel"] == "C0123ABC" + assert float(seen[0].url.params["oldest"]) == pytest.approx(cutoff.timestamp()) + # And the run was persisted. + rows = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert len(rows) == 1 + + async def test_no_new_messages_returns_none( + self, sqlite_session, patch_slack_transport + ): + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok([]))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + rows = (await sqlite_session.execute(select(AutomationRun))).scalars().all() + assert rows == [] + + async def test_slack_ok_false_treated_as_no_messages( + self, sqlite_session, patch_slack_transport + ): + """Slack returns HTTP 200 for app errors; ``ok=false`` must not fire.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport( + lambda req: httpx.Response( + 200, json={"ok": False, "error": "channel_not_found"} + ) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_rate_limit_treated_as_no_messages( + self, sqlite_session, patch_slack_transport + ): + """HTTP 429 (rate limit) is logged and treated as no new messages.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + patch_slack_transport( + lambda req: httpx.Response(429, headers={"Retry-After": "5"}, json={}) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_message_at_or_before_cutoff_excluded( + self, sqlite_session, patch_slack_transport + ): + """A message exactly at the cutoff or older must not fire the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + # Slack's `oldest` is server-side exclusive — but it's perfectly + # reasonable for the API to still hand us a message at the cutoff + # (e.g. if the float rounds). Our client-side check must reject it. + stale = _slack_message("m_old", cutoff) + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok([stale]))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_jmespath_filter_keeps_matching_messages( + self, sqlite_session, patch_slack_transport + ): + """A truthy JMESPath result keeps the message and fires the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msgs = [ + _slack_message( + "join", + cutoff + timedelta(seconds=10), + subtype="channel_join", + text="<@U1> has joined", + ), + _slack_message( + "real", + cutoff + timedelta(seconds=20), + text="real user message", + ), + ] + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok(msgs))) + + # Filter out join/leave noise — only regular user messages. + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="subtype == null", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["messages"] + assert [m["client_msg_id"] for m in kept] == ["real"] + + async def test_jmespath_filter_excludes_all( + self, sqlite_session, patch_slack_transport + ): + """A filter that matches nothing must not fire the trigger.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + msgs = [ + _slack_message("a", cutoff + timedelta(seconds=10), user="U1"), + _slack_message("b", cutoff + timedelta(seconds=20), user="U2"), + ] + patch_slack_transport(lambda req: httpx.Response(200, json=_slack_ok(msgs))) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + event_filter="user == 'UNOBODY'", + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + + async def test_multiple_channels_polled_in_parallel( + self, sqlite_session, patch_slack_transport + ): + """Two configured channels are each polled; both contribute messages.""" + import httpx + + cutoff = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + + def responder(request: httpx.Request) -> httpx.Response: + ch = request.url.params["channel"] + msg = _slack_message(f"msg_{ch}", cutoff + timedelta(seconds=5)) + return httpx.Response(200, json=_slack_ok([msg])) + + seen = patch_slack_transport(responder) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC", "C0456DEF"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + last_triggered_at=cutoff, + ) + + run = await trigger.create_pending_run(sqlite_session, automation) + assert run is not None + kept = run.event_payload["messages"] + assert {m["_channel"] for m in kept} == {"C0123ABC", "C0456DEF"} + polled = {r.url.params["channel"] for r in seen} + assert polled == {"C0123ABC", "C0456DEF"} + + async def test_disabled_short_circuits_without_calling_slack( + self, sqlite_session, patch_slack_transport + ): + """A disabled automation must not call Slack at all.""" + import httpx + + seen = patch_slack_transport( + lambda req: httpx.Response(200, json=_slack_ok([])) + ) + + trigger = SlackTrigger( + slack_token=SecretStr("xoxb-xxx"), + channels=["C0123ABC"], + ) + automation = await _make_automation( + sqlite_session, + trigger=trigger.model_dump(mode="python"), + enabled=False, + ) + assert await trigger.create_pending_run(sqlite_session, automation) is None + assert seen == []