From 8f923be3a1ee44528087660a3c08874e452f9c78 Mon Sep 17 00:00:00 2001 From: Qasim Date: Fri, 26 Jun 2026 08:19:54 -0400 Subject: [PATCH 1/2] TW-5722: add JSON-RPC server over WebSocket exposing the CLI surface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `nylas rpc serve` — a JSON-RPC 2.0 server over WebSocket exposing the CLI surface (email, calendar, contacts, threads, drafts, agent accounts, scheduler, notetaker, admin, auth grant ops) as a second inbound adapter over the existing ports.*Client interfaces. - Full CRUD parity for email, calendar/events, contacts, and API grant ops; folders, attachments, signatures, scheduled sends, free/busy, availability, recurring instances, room resources, and virtual-calendar grants - Attachment download capped at 30 MiB; per-grant resolution with default grant - Incremental pollers (messages, threads, events, contacts) with notifications - Integration suite launches the real binary over WebSocket (read-only smoke + protocol/notification/read/write coverage) - docs/RPC.md documents the method surface Also correct the agent rule `assign_to_folder` hint and docs: the value is now a folder name (UUID folder IDs no longer resolve), matching the API change in developer.nylas.com#630. --- Makefile | 17 +- cmd/nylas/main.go | 2 + docs/RPC.md | 405 +++++++++++++ docs/commands/agent-rule.md | 9 + go.mod | 2 +- internal/adapters/nylas/calendars_events.go | 2 + .../adapters/nylas/calendars_events_test.go | 8 + internal/adapters/nylas/contacts.go | 2 + .../adapters/nylas/contacts_internal_test.go | 17 + internal/adapters/nylas/demo_threads.go | 7 + internal/adapters/nylas/mock_client.go | 1 + internal/adapters/nylas/mock_threads.go | 14 + internal/adapters/nylas/threads.go | 27 +- internal/adapters/nylas/threads_test.go | 60 +- internal/adapters/rpcserver/auth.go | 93 +++ internal/adapters/rpcserver/auth_test.go | 244 ++++++++ internal/adapters/rpcserver/handlers.go | 27 + internal/adapters/rpcserver/handlers_admin.go | 523 ++++++++++++++++ .../adapters/rpcserver/handlers_admin_test.go | 461 +++++++++++++++ internal/adapters/rpcserver/handlers_audit.go | 154 +++++ .../adapters/rpcserver/handlers_audit_test.go | 419 +++++++++++++ internal/adapters/rpcserver/handlers_auth.go | 128 ++++ .../adapters/rpcserver/handlers_auth_test.go | 474 +++++++++++++++ .../adapters/rpcserver/handlers_calendar.go | 117 ++++ .../rpcserver/handlers_calendar_ext.go | 360 +++++++++++ .../rpcserver/handlers_calendar_ext_test.go | 476 +++++++++++++++ .../rpcserver/handlers_calendar_test.go | 227 +++++++ .../rpcserver/handlers_calendar_write.go | 128 ++++ .../rpcserver/handlers_calendar_write_test.go | 283 +++++++++ .../rpcserver/handlers_contact_ext.go | 168 ++++++ .../rpcserver/handlers_contact_ext_test.go | 263 +++++++++ .../rpcserver/handlers_contact_write.go | 87 +++ .../rpcserver/handlers_contact_write_test.go | 205 +++++++ .../adapters/rpcserver/handlers_contacts.go | 76 +++ .../rpcserver/handlers_contacts_test.go | 214 +++++++ internal/adapters/rpcserver/handlers_draft.go | 61 ++ .../adapters/rpcserver/handlers_draft_test.go | 198 +++++++ internal/adapters/rpcserver/handlers_email.go | 78 +++ .../adapters/rpcserver/handlers_email_ext.go | 505 ++++++++++++++++ .../rpcserver/handlers_email_ext_test.go | 557 ++++++++++++++++++ .../adapters/rpcserver/handlers_email_test.go | 236 ++++++++ .../rpcserver/handlers_email_write.go | 189 ++++++ .../rpcserver/handlers_email_write_test.go | 287 +++++++++ internal/adapters/rpcserver/handlers_local.go | 122 ++++ .../adapters/rpcserver/handlers_local_test.go | 247 ++++++++ .../adapters/rpcserver/handlers_notetaker.go | 195 ++++++ .../rpcserver/handlers_notetaker_test.go | 463 +++++++++++++++ internal/adapters/rpcserver/handlers_otp.go | 41 ++ .../adapters/rpcserver/handlers_otp_test.go | 102 ++++ .../adapters/rpcserver/handlers_scheduler.go | 372 ++++++++++++ .../rpcserver/handlers_scheduler_test.go | 277 +++++++++ .../adapters/rpcserver/handlers_templates.go | 421 +++++++++++++ .../rpcserver/handlers_templates_test.go | 458 ++++++++++++++ .../adapters/rpcserver/handlers_thread.go | 80 +++ .../rpcserver/handlers_thread_test.go | 295 ++++++++++ .../rpcserver/handlers_thread_write.go | 64 ++ .../rpcserver/handlers_thread_write_test.go | 205 +++++++ internal/adapters/rpcserver/incremental.go | 166 ++++++ .../adapters/rpcserver/incremental_test.go | 108 ++++ internal/adapters/rpcserver/jsonrpc.go | 154 +++++ internal/adapters/rpcserver/jsonrpc_test.go | 265 +++++++++ .../adapters/rpcserver/poller_contacts.go | 192 ++++++ .../rpcserver/poller_contacts_test.go | 475 +++++++++++++++ internal/adapters/rpcserver/poller_events.go | 90 +++ .../adapters/rpcserver/poller_events_test.go | 322 ++++++++++ .../adapters/rpcserver/poller_messages.go | 100 ++++ .../rpcserver/poller_messages_test.go | 378 ++++++++++++ internal/adapters/rpcserver/poller_threads.go | 91 +++ .../adapters/rpcserver/poller_threads_test.go | 305 ++++++++++ internal/adapters/rpcserver/server.go | 208 +++++++ internal/adapters/rpcserver/server_test.go | 241 ++++++++ internal/cli/agent/rule_validation.go | 2 +- .../cli/integration/rpc_ext_smoke_test.go | 78 +++ internal/cli/integration/rpc_extended_test.go | 78 +++ .../cli/integration/rpc_notifications_test.go | 58 ++ internal/cli/integration/rpc_protocol_test.go | 212 +++++++ internal/cli/integration/rpc_reads_test.go | 224 +++++++ internal/cli/integration/rpc_test.go | 161 +++++ internal/cli/integration/rpc_testutil_test.go | 170 ++++++ internal/cli/integration/rpc_writes_test.go | 187 ++++++ internal/cli/rpc/rpc.go | 16 + internal/cli/rpc/serve.go | 197 +++++++ internal/domain/calendar.go | 2 + internal/domain/contact.go | 1 + internal/ports/messages.go | 3 + 85 files changed, 15624 insertions(+), 13 deletions(-) create mode 100644 docs/RPC.md create mode 100644 internal/adapters/nylas/contacts_internal_test.go create mode 100644 internal/adapters/rpcserver/auth.go create mode 100644 internal/adapters/rpcserver/auth_test.go create mode 100644 internal/adapters/rpcserver/handlers.go create mode 100644 internal/adapters/rpcserver/handlers_admin.go create mode 100644 internal/adapters/rpcserver/handlers_admin_test.go create mode 100644 internal/adapters/rpcserver/handlers_audit.go create mode 100644 internal/adapters/rpcserver/handlers_audit_test.go create mode 100644 internal/adapters/rpcserver/handlers_auth.go create mode 100644 internal/adapters/rpcserver/handlers_auth_test.go create mode 100644 internal/adapters/rpcserver/handlers_calendar.go create mode 100644 internal/adapters/rpcserver/handlers_calendar_ext.go create mode 100644 internal/adapters/rpcserver/handlers_calendar_ext_test.go create mode 100644 internal/adapters/rpcserver/handlers_calendar_test.go create mode 100644 internal/adapters/rpcserver/handlers_calendar_write.go create mode 100644 internal/adapters/rpcserver/handlers_calendar_write_test.go create mode 100644 internal/adapters/rpcserver/handlers_contact_ext.go create mode 100644 internal/adapters/rpcserver/handlers_contact_ext_test.go create mode 100644 internal/adapters/rpcserver/handlers_contact_write.go create mode 100644 internal/adapters/rpcserver/handlers_contact_write_test.go create mode 100644 internal/adapters/rpcserver/handlers_contacts.go create mode 100644 internal/adapters/rpcserver/handlers_contacts_test.go create mode 100644 internal/adapters/rpcserver/handlers_draft.go create mode 100644 internal/adapters/rpcserver/handlers_draft_test.go create mode 100644 internal/adapters/rpcserver/handlers_email.go create mode 100644 internal/adapters/rpcserver/handlers_email_ext.go create mode 100644 internal/adapters/rpcserver/handlers_email_ext_test.go create mode 100644 internal/adapters/rpcserver/handlers_email_test.go create mode 100644 internal/adapters/rpcserver/handlers_email_write.go create mode 100644 internal/adapters/rpcserver/handlers_email_write_test.go create mode 100644 internal/adapters/rpcserver/handlers_local.go create mode 100644 internal/adapters/rpcserver/handlers_local_test.go create mode 100644 internal/adapters/rpcserver/handlers_notetaker.go create mode 100644 internal/adapters/rpcserver/handlers_notetaker_test.go create mode 100644 internal/adapters/rpcserver/handlers_otp.go create mode 100644 internal/adapters/rpcserver/handlers_otp_test.go create mode 100644 internal/adapters/rpcserver/handlers_scheduler.go create mode 100644 internal/adapters/rpcserver/handlers_scheduler_test.go create mode 100644 internal/adapters/rpcserver/handlers_templates.go create mode 100644 internal/adapters/rpcserver/handlers_templates_test.go create mode 100644 internal/adapters/rpcserver/handlers_thread.go create mode 100644 internal/adapters/rpcserver/handlers_thread_test.go create mode 100644 internal/adapters/rpcserver/handlers_thread_write.go create mode 100644 internal/adapters/rpcserver/handlers_thread_write_test.go create mode 100644 internal/adapters/rpcserver/incremental.go create mode 100644 internal/adapters/rpcserver/incremental_test.go create mode 100644 internal/adapters/rpcserver/jsonrpc.go create mode 100644 internal/adapters/rpcserver/jsonrpc_test.go create mode 100644 internal/adapters/rpcserver/poller_contacts.go create mode 100644 internal/adapters/rpcserver/poller_contacts_test.go create mode 100644 internal/adapters/rpcserver/poller_events.go create mode 100644 internal/adapters/rpcserver/poller_events_test.go create mode 100644 internal/adapters/rpcserver/poller_messages.go create mode 100644 internal/adapters/rpcserver/poller_messages_test.go create mode 100644 internal/adapters/rpcserver/poller_threads.go create mode 100644 internal/adapters/rpcserver/poller_threads_test.go create mode 100644 internal/adapters/rpcserver/server.go create mode 100644 internal/adapters/rpcserver/server_test.go create mode 100644 internal/cli/integration/rpc_ext_smoke_test.go create mode 100644 internal/cli/integration/rpc_extended_test.go create mode 100644 internal/cli/integration/rpc_notifications_test.go create mode 100644 internal/cli/integration/rpc_protocol_test.go create mode 100644 internal/cli/integration/rpc_reads_test.go create mode 100644 internal/cli/integration/rpc_test.go create mode 100644 internal/cli/integration/rpc_testutil_test.go create mode 100644 internal/cli/integration/rpc_writes_test.go create mode 100644 internal/cli/rpc/rpc.go create mode 100644 internal/cli/rpc/serve.go diff --git a/Makefile b/Makefile index 9692735..6c7aae6 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build test-unit test-race test-integration test-integration-fast test-cli-regressions test-integration-agent test-cleanup test-coverage test-air test-air-integration test-e2e test-e2e-air test-e2e-ui test-playwright test-playwright-air test-playwright-ui test-playwright-studio test-playwright-interactive test-playwright-headed clean clean-cache install fmt vet lint vuln deps security check-context ci ci-full help +.PHONY: build test-unit test-race test-integration test-integration-fast test-cli-regressions test-integration-agent test-integration-rpc test-cleanup test-coverage test-air test-air-integration test-e2e test-e2e-air test-e2e-ui test-playwright test-playwright-air test-playwright-ui test-playwright-studio test-playwright-interactive test-playwright-headed clean clean-cache install fmt vet lint vuln deps security check-context ci ci-full help # Disable parallel Make execution - prevents Go build cache corruption on btrfs (CachyOS) .NOTPARALLEL: @@ -186,6 +186,21 @@ test-integration-agent: build -run 'TestCLI_Agent.*$$' @echo "✓ Agent integration checks passed" +# RPC WebSocket server integration checks: boots `nylas rpc serve`, verifies token auth +# (wrong token rejected, correct token connects) and a live email.list over JSON-RPC. +test-integration-rpc: build + @echo "=== Running RPC Server Integration Checks ===" + @: "$${NYLAS_API_KEY:?NYLAS_API_KEY is required for rpc integration tests}" + @: "$${NYLAS_GRANT_ID:?NYLAS_GRANT_ID is required for rpc integration tests}" + @go clean -testcache + NYLAS_DISABLE_KEYRING=true \ + NYLAS_TEST_RATE_LIMIT_RPS=$(NYLAS_TEST_RATE_LIMIT_RPS) \ + NYLAS_TEST_RATE_LIMIT_BURST=$(NYLAS_TEST_RATE_LIMIT_BURST) \ + NYLAS_TEST_BINARY=$(CURDIR)/bin/nylas \ + go test ./internal/cli/integration/... -tags=integration -v -timeout 10m -p 1 \ + -run 'TestCLI_RPC.*$$' + @echo "✓ RPC server integration checks passed" + # Clean up test resources (virtual calendars, test grants, test events, test emails, etc.) test-cleanup: @echo "=== Cleaning up test resources ===" diff --git a/cmd/nylas/main.go b/cmd/nylas/main.go index cb73b87..af0fee6 100644 --- a/cmd/nylas/main.go +++ b/cmd/nylas/main.go @@ -23,6 +23,7 @@ import ( "github.com/nylas/cli/internal/cli/mcp" "github.com/nylas/cli/internal/cli/notetaker" "github.com/nylas/cli/internal/cli/otp" + "github.com/nylas/cli/internal/cli/rpc" "github.com/nylas/cli/internal/cli/scheduler" "github.com/nylas/cli/internal/cli/setup" "github.com/nylas/cli/internal/cli/slack" @@ -58,6 +59,7 @@ func main() { rootCmd.AddCommand(notetaker.NewNotetakerCmd()) rootCmd.AddCommand(timezone.NewTimezoneCmd()) rootCmd.AddCommand(mcp.NewMCPCmd()) + rootCmd.AddCommand(rpc.NewRPCCmd()) rootCmd.AddCommand(slack.NewSlackCmd()) rootCmd.AddCommand(templatecmd.NewTemplateCmd()) rootCmd.AddCommand(demo.NewDemoCmd()) diff --git a/docs/RPC.md b/docs/RPC.md new file mode 100644 index 0000000..2ad66ee --- /dev/null +++ b/docs/RPC.md @@ -0,0 +1,405 @@ +# Nylas CLI — RPC Server (`nylas rpc serve`) + +A local **JSON-RPC 2.0 server over WebSocket** that exposes the Nylas CLI's full +capability surface to a thin client (for example a desktop app). The CLI binary is the +engine — it holds the credentials, runs the live pollers, and owns all business logic. +Clients are thin: they send requests and render the streamed results and notifications. + +- **Endpoint:** `ws://127.0.0.1:7368/ws` +- **Protocol:** JSON-RPC 2.0 (bidirectional) over WebSocket +- **Auth:** per-session bearer token, loopback-only bind +- **Surface:** ~108 methods across 18 domains + live push notifications + +--- + +## Table of contents + +1. [Quick start](#quick-start) +2. [Architecture](#architecture) +3. [Transport & message format](#transport--message-format) +4. [Authentication & security](#authentication--security) +5. [Configuration](#configuration) +6. [Error codes](#error-codes) +7. [Method reference](#method-reference) +8. [Notifications (server → client push)](#notifications-server--client-push) +9. [Examples](#examples) +10. [Testing](#testing) +11. [Limitations & scope](#limitations--scope) + +--- + +## Quick start + +Start the server: + +```bash +nylas rpc serve # binds 127.0.0.1:7368 +nylas rpc serve --addr 127.0.0.1:9000 +``` + +On first run the server generates a session token, stores it in the OS keyring, and +prints how to authenticate. To use a known token (headless / scripting): + +```bash +NYLAS_WS_TOKEN=my-secret nylas rpc serve +``` + +Connect a WebSocket client to `ws://127.0.0.1:7368/ws` with the token, then send a request: + +```json +{ "jsonrpc": "2.0", "id": 1, "method": "email.list", "params": { "limit": 10 } } +``` + +--- + +## Architecture + +The CLI follows a hexagonal architecture (CLI → Port → Adapter). The RPC server is a +**second inbound adapter** alongside the cobra CLI, calling the same `ports.NylasClient`: + +``` +client ──ws (JSON-RPC request)──▶ rpcserver ──▶ ports.NylasClient ──▶ Nylas API +client ◀─ws (JSON-RPC notification)── rpcserver ◀── pollers (received_after / updated_after / …) + 127.0.0.1 + session token + Origin check ; creds from the OS keyring +``` + +- **No business logic in the client.** It can't run an operation itself; it only talks to the server. +- **The server owns:** credentials (keyring), the incremental pollers, session/connection state. +- **Code:** `internal/cli/rpc/` (command) and `internal/adapters/rpcserver/` (server + handlers + pollers). + +--- + +## Transport & message format + +WebSocket at path `/ws`. Messages are JSON-RPC 2.0 objects, one per WebSocket text frame. + +**Request** (has `id`): +```json +{ "jsonrpc": "2.0", "id": 1, "method": "email.list", "params": { "limit": 5 } } +``` + +**Response** (matches `id`): +```json +{ "jsonrpc": "2.0", "id": 1, "result": { "messages": [ … ], "next_cursor": "…", "has_more": true } } +``` + +**Error response:** +```json +{ "jsonrpc": "2.0", "id": 1, "error": { "code": -32602, "message": "message_id required" } } +``` + +**Notification** (server → client, no `id`): +```json +{ "jsonrpc": "2.0", "method": "message.received", "params": { "id": "…", "subject": "…" } } +``` + +Notes: +- Parse-error / invalid-request responses carry `"id": null` per the spec. +- A **notification from the client** (a request with no `id`) is executed but gets **no reply** — + used for `client.focus` (see [adaptive polling](#adaptive-polling)). +- `page_token` / `next_cursor` are **opaque** — store and replay verbatim per grant; the format + differs by provider (e.g. base64-JSON for the `nylas` provider vs numeric for Google). + +--- + +## Authentication & security + +The server holds live Nylas credentials, so the local socket is a real trust boundary. + +- **Bearer token** on the WebSocket upgrade, via either: + - `Authorization: Bearer ` header, or + - `?token=` query parameter. + Wrong/missing token → **401**. Comparison is constant-time over SHA-256 digests (no length leak). +- **Token lifecycle:** generated once (32 bytes, `crypto/rand`, base64url), persisted in the OS + keyring (`rpc_session_token`); `NYLAS_WS_TOKEN` overrides. Reused across restarts (no rotation/expiry). +- **Loopback only:** binds `127.0.0.1`. A non-loopback `--addr` is **refused** unless `--allow-remote` + is passed (then it warns). Never expose a credential-holding socket to the network unauthenticated. +- **Origin check:** non-empty `Origin` headers are rejected (blocks browser-based CSWSH / DNS-rebinding). +- **Generic client errors:** internal/upstream error detail is logged to the server's stderr; the + client receives a generic `-32603 "internal error"` (intentional RPC errors like `-32602` pass through). +- **`config.read` is whitelisted** — it never returns secrets, grants, or AI/GPG/dashboard sub-objects + (only boolean presence flags). +- **Writes execute immediately** — there is no server-side confirmation prompt. The **client** is + responsible for confirming destructive/outbound operations before calling. + +--- + +## Configuration + +| Flag / Env | Purpose | Default | +|---|---|---| +| `--addr` | bind address | `127.0.0.1:7368` | +| `NYLAS_WS_ADDR` | bind address (env; `--addr` wins) | `127.0.0.1:7368` | +| `--allow-remote` | permit a non-loopback bind (warns) | `false` | +| `NYLAS_WS_TOKEN` | inject the session token (headless/CI) | auto-generated, keyring-brokered | +| `NYLAS_DISABLE_KEYRING` | store token/creds in `~/.config/nylas` instead of the keyring | `false` | + +The server resolves the Nylas API credentials and default grant the same way the rest of the +CLI does (keyring, or env/file when `NYLAS_DISABLE_KEYRING=true`). Live pollers run only when a +**default grant** is configured; otherwise the server still serves requests and prints a notice. + +--- + +## Error codes + +Standard JSON-RPC 2.0 codes: + +| Code | Meaning | When | +|---|---|---| +| `-32700` | Parse error | malformed JSON (response `id` is `null`) | +| `-32600` | Invalid request | missing/incorrect `jsonrpc` field | +| `-32601` | Method not found | unknown method | +| `-32602` | Invalid params | missing required param / bad value (e.g. `message_id required`) | +| `-32603` | Internal error | upstream/handler failure (detail logged server-side, generic to client) | + +--- + +## Method reference + +Conventions: +- `grant_id` is **optional** on per-grant methods — it falls back to the server's default grant. + App-level methods (admin, scheduler configs, etc.) take **no** grant. +- Required ids return `-32602` when missing. +- Create/update params **embed the corresponding `domain.*Request` struct** at the top level — i.e. + the request fields sit alongside `grant_id`/ids (see `internal/domain` for exact fields). +- Delete-style methods return `{ "deleted": true }` (or `{ "revoked": true }` / `{ "cancelled": true }`). + +### Email +| Method | Params | Result | +|---|---|---| +| `email.list` | `grant_id?, limit?, page_token?, received_after?` | `{ messages, next_cursor, has_more }` | +| `email.get` | `grant_id?, message_id` | message | +| `email.send` | `grant_id?` + `SendMessageRequest` | message | +| `email.update` | `grant_id?, message_id` + `UpdateMessageRequest` | message | +| `email.delete` | `grant_id?, message_id` | `{ deleted }` | +| `email.clean` | `grant_id?` + `CleanMessagesRequest` | `{ messages }` (cleaned) | +| `email.folder.list` | `grant_id?` | `{ folders }` | +| `email.folder.get` | `grant_id?, folder_id` | folder | +| `email.folder.create` | `grant_id?` + `CreateFolderRequest` | folder | +| `email.folder.update` | `grant_id?, folder_id` + `UpdateFolderRequest` | folder | +| `email.folder.delete` | `grant_id?, folder_id` | `{ deleted }` | +| `email.attachment.list` | `grant_id?, message_id` | `{ attachments }` | +| `email.attachment.get` | `grant_id?, message_id, attachment_id` | attachment (metadata) | +| `email.attachment.download` | `grant_id?, message_id, attachment_id` | `{ content (base64), size }` | +| `email.signature.list` | `grant_id?` | `{ signatures }` | +| `email.signature.get` | `grant_id?, signature_id` | signature | +| `email.signature.create` | `grant_id?` + `CreateSignatureRequest` | signature | +| `email.signature.update` | `grant_id?, signature_id` + `UpdateSignatureRequest` | signature | +| `email.signature.delete` | `grant_id?, signature_id` | `{ deleted }` | +| `email.scheduled.list` | `grant_id?` | `{ scheduled }` | +| `email.scheduled.get` | `grant_id?, schedule_id` | scheduled message | +| `email.scheduled.cancel` | `grant_id?, schedule_id` | `{ cancelled }` | + +### Drafts +| Method | Params | Result | +|---|---|---| +| `draft.list` | `grant_id?, limit?` | `{ drafts }` | +| `draft.get` | `grant_id?, draft_id` | draft | +| `draft.create` | `grant_id?` + `CreateDraftRequest` | draft | +| `draft.update` | `grant_id?, draft_id` + `CreateDraftRequest` | draft | +| `draft.delete` | `grant_id?, draft_id` | `{ deleted }` | +| `draft.send` | `grant_id?, draft_id` + `SendDraftRequest` | message | + +### Threads +| Method | Params | Result | +|---|---|---| +| `thread.list` | `grant_id?, limit?, page_token?, latest_message_after?, unread?` | `{ threads, next_cursor, has_more }` | +| `thread.get` | `grant_id?, thread_id` | thread | +| `thread.update` | `grant_id?, thread_id` + `UpdateMessageRequest` (unread/starred/folders) | thread | +| `thread.delete` | `grant_id?, thread_id` | `{ deleted }` | + +### Calendar & events +| Method | Params | Result | +|---|---|---| +| `calendar.list` | `grant_id?` | `{ calendars }` | +| `event.list` | `grant_id?, calendar_id=primary, limit?, page_token?, updated_after?, start?, end?` | `{ events, next_cursor, has_more }` | +| `event.get` | `grant_id?, calendar_id=primary, event_id` | event | +| `event.create` | `grant_id?, calendar_id=primary` + `CreateEventRequest` | event | +| `event.update` | `grant_id?, calendar_id=primary, event_id` + `UpdateEventRequest` | event | +| `event.delete` | `grant_id?, calendar_id=primary, event_id` | `{ deleted }` | +| `event.rsvp` | `grant_id?, calendar_id=primary, event_id` + `SendRSVPRequest` | `{ ok }` | +| `event.import` | `grant_id?` + `EventQueryParams` (incl. `calendar_id`, `start`, `end`) | `{ events }` | +| `event.recurring.list` | `grant_id?, calendar_id=primary, master_event_id` + `EventQueryParams` | `{ events }` (instances) | +| `event.recurring.update` | `grant_id?, calendar_id=primary, event_id` + `UpdateEventRequest` | event | +| `event.recurring.delete` | `grant_id?, calendar_id=primary, event_id` | `{ deleted }` | +| `calendar.get` | `grant_id?, calendar_id` | calendar | +| `calendar.create` | `grant_id?` + `CreateCalendarRequest` | calendar | +| `calendar.update` | `grant_id?, calendar_id` + `UpdateCalendarRequest` | calendar | +| `calendar.delete` | `grant_id?, calendar_id` | `{ deleted }` | +| `calendar.freeBusy` | `grant_id?` + `FreeBusyRequest` | free/busy response | +| `calendar.availability` | `AvailabilityRequest` (**no grant**) | availability response | +| `calendar.resources` | `grant_id?` | `{ resources }` (bookable rooms) | +| `calendar.virtual.create` | `email` (**no grant**) | virtual calendar grant | +| `calendar.virtual.list` | — (**no grant**) | `{ grants }` | +| `calendar.virtual.get` | `grant_id` (the virtual grant id) | virtual calendar grant | +| `calendar.virtual.delete` | `grant_id` (the virtual grant id) | `{ deleted }` | + +### Contacts +| Method | Params | Result | +|---|---|---| +| `contact.list` | `grant_id?, limit?, page_token?` | `{ contacts, next_cursor, has_more }` | +| `contact.get` | `grant_id?, contact_id` | contact | +| `contact.create` | `grant_id?` + `CreateContactRequest` | contact | +| `contact.update` | `grant_id?, contact_id` + `UpdateContactRequest` | contact | +| `contact.delete` | `grant_id?, contact_id` | `{ deleted }` | +| `contact.getWithPicture` | `grant_id?, contact_id, include_picture?` | contact (with base64 `picture` when requested) | +| `contact.group.list` | `grant_id?` | `{ groups }` | +| `contact.group.get` | `grant_id?, group_id` | contact group | +| `contact.group.create` | `grant_id?` + `CreateContactGroupRequest` | contact group | +| `contact.group.update` | `grant_id?, group_id` + `UpdateContactGroupRequest` | contact group | +| `contact.group.delete` | `grant_id?, group_id` | `{ deleted }` | + +### Agent accounts / grants / config +| Method | Params | Result | +|---|---|---| +| `agentAccount.list` | — | `{ accounts }` | +| `agentAccount.get` | `grant_id` (the agent account's grant) | account | +| `grant.list` | — | `{ grants }` (local store: id/email/provider) | +| `config.read` | — | whitelisted config (region, default_grant, callback_port, tui_theme, api{base_url,timeout}, working_hours, ai/gpg/dashboard `*_configured` booleans). **No secrets.** | + +### Notetaker +`notetaker.list` (`grant_id?` + query) · `notetaker.get` · `notetaker.create` · `notetaker.update` +· `notetaker.delete` → `{ deleted }` · `notetaker.leave` → `{ left }` · `notetaker.media` +(all per-grant; `notetaker_id` required where applicable). + +### Scheduler +- Configs: `scheduler.config.list` / `.get` / `.create` / `.update` / `.delete` +- Sessions: `scheduler.session.create` / `.get` +- Bookings: `scheduler.booking.get` / `.confirm` / `.reschedule` / `.cancel` (`{ cancelled }`) +- Group events (per-grant): `scheduler.groupEvent.list` (requires `config_id`, `calendar_id`, + `start_time`, `end_time`) / `.create` / `.update` / `.delete` / `.import` + +### Templates & workflows +A `scope` param selects `"app"` (default) or `"grant"`; only `"grant"` scope requires a grant. +- Templates: `template.list` / `.get` / `.create` / `.update` / `.delete` / `.render` / `.renderHTML` +- Workflows: `workflow.list` / `.get` / `.create` / `.update` / `.delete` + +### Admin & workspaces (app-level; no grant) +- Applications: `admin.app.list` / `.get` / `.create` / `.update` / `.delete` +- Callback URIs: `admin.callbackUri.list` / `.get` / `.create` / `.update` / `.delete` +- Connectors: `admin.connector.list` / `.get` / `.create` / `.update` / `.delete` +- **Credentials (secret material):** `admin.credential.list` / `.get` / `.create` / `.update` / `.delete` +- Workspaces: `workspace.list` / `.get` / `.create` / `.update` / `.delete` / `workspace.assignGrants` +- Grants admin: `admin.grants.listAll` / `admin.grants.stats` + +### Auth +| Method | Params | Result | +|---|---|---| +| `auth.grant.get` | `grant_id?` | grant | +| `auth.grant.revoke` | `grant_id?` | `{ revoked }` | +| `auth.grant.createCustom` | `provider, settings` | grant | +| `auth.url` | `provider, redirect_uri, state?, code_challenge?` | `{ url }` (pure builder; no API call) | +| `auth.grant.exchange` | `code, redirect_uri, code_verifier?` | grant (completes the OAuth code→grant round-trip) | + +> The interactive OAuth login flow (opening a browser + running a local callback server) is **not** +> exposed over RPC. A GUI runs its own redirect, then calls `auth.url` → `auth.grant.exchange`. +> Local CLI session commands (`whoami`, `switch`, `token`, `status`) are intentionally CLI-only. + +### Audit (local audit log) +`audit.list` (`limit?`) · `audit.query` (`AuditQueryOptions`) · `audit.summary` (`days?`) · +`audit.stats` · `audit.config.read` · `audit.config.save` (`{ ok }`) · `audit.path` · +`audit.clear` (`{ cleared }`) · `audit.cleanup` (`{ ok }`). + +### OTP +| Method | Params | Result | +|---|---|---| +| `otp.get` | `email?` (omit → default grant) | `OTPResult` (code/from/subject/received/message_id) | + +--- + +## Notifications (server → client push) + +When a default grant is configured, the server runs incremental pollers and pushes notifications +(no `id`) to all connected clients: + +| Method | Fires when | +|---|---| +| `message.received` | a new message arrives | +| `thread.updated` | a thread has new activity | +| `event.updated` | a calendar event is created or edited (per calendar) | +| `contact.updated` | a contact is created or its content changes (SHA-256 fingerprint diff) | +| `contact.deleted` | a contact disappears from the address book | + +Polling cursors: messages use `received_after`, threads `latest_message_after`, events +`updated_after`; contacts have no server-side time filter so the poller refetches and diffs on a +content fingerprint. Filters are **exclusive**, so pollers query `cursor-1` and dedupe boundary +records by id. + +### Adaptive polling + +Send a `client.focus` **notification** (no `id`) to scale the poll interval: + +```json +{ "jsonrpc": "2.0", "method": "client.focus", "params": { "focused": true } } +``` + +- `focused: true` → fast interval (5s) for message/thread/event pollers. +- `focused: false` → idle interval (30s). Contacts always poll on a slow 60s cadence. + +--- + +## Examples + +Authenticate, list mail, and subscribe to live notifications (pseudocode): + +```js +const ws = new WebSocket("ws://127.0.0.1:7368/ws", { + headers: { Authorization: `Bearer ${token}` }, +}); + +// request / response +ws.send(JSON.stringify({ jsonrpc: "2.0", id: 1, method: "email.list", params: { limit: 10 } })); + +// tell the server we're focused → faster polling +ws.send(JSON.stringify({ jsonrpc: "2.0", method: "client.focus", params: { focused: true } })); + +ws.onmessage = (e) => { + const msg = JSON.parse(e.data); + if (msg.id === 1) console.log("emails:", msg.result.messages); + else if (msg.method === "message.received") console.log("new mail:", msg.params.subject); +}; +``` + +Send an email (the client confirms first; the server executes immediately): + +```json +{ "jsonrpc": "2.0", "id": 2, "method": "email.send", + "params": { "to": [{ "email": "a@b.com" }], "subject": "Hi", "body": "Hello" } } +``` + +--- + +## Testing + +| Layer | What | +|---|---| +| Unit (`-race`) | every handler, all pollers (boundary/truncation/fingerprint/deleted/seed), dispatcher, auth, adaptive intervals, WebSocket concurrency | +| Adapter (httptest) | query-building (cursors, filters) | +| Integration (live API) | `make test-integration-rpc` — protocol/auth edge cases, all reads, reversible write round-trips (draft/contact/event create→delete), a live `message.received` round-trip, extended-domain reads | + +```bash +make test-integration-rpc # requires NYLAS_API_KEY + NYLAS_GRANT_ID +``` + +Integration tests live in `internal/cli/integration/rpc_*_test.go` (build tag `integration`). + +--- + +## Limitations & scope + +- **Dashboard is not exposed** — its auth is interactive (login/MFA/SSO) and its app/API-key ops + return secret material; that needs a separate, dedicated design. +- **Extended-domain writes** (admin/scheduler/template/workflow/notetaker create/update/delete) are + **unit-tested only**, not live-integration-tested — exercising them creates real app-level or + secret resources, which isn't safe in a test suite. Their reads are integration-verified. +- **CLI-only conveniences are not exposed:** GPG sign/encrypt, hosted-template rendering shortcuts, + recipient-string parsing, raw-MIME send, attachment-from-path. +- **No write-path rate limiting** and **no distinct error codes** for not-found vs rate-limited + (everything upstream maps to `-32603`) — acceptable within the loopback+token threat model; both + are candidate follow-ups. +- **Token has no rotation/expiry** — rotate by deleting the keyring entry (or changing + `NYLAS_WS_TOKEN`) and restarting. + +--- + +*Code: `internal/cli/rpc/`, `internal/adapters/rpcserver/`. Tracked in Jira TW-5722 (epic TW-5721).* diff --git a/docs/commands/agent-rule.md b/docs/commands/agent-rule.md index 53935f1..5b661cc 100644 --- a/docs/commands/agent-rule.md +++ b/docs/commands/agent-rule.md @@ -70,6 +70,15 @@ nylas agent rule create \ --action archive ``` +```bash +nylas agent rule create \ + --name "File receipts" \ + --condition from.domain,is,billing.example.com \ + --action assign_to_folder=Receipts +``` + +The `assign_to_folder` value is a folder **name** — a custom folder's name (use its full path for a nested folder, e.g. `Clients/Acme`) or a system folder name (`Inbox`, `Sent`, `Drafts`, `Trash`, `Junk`, `Archive`). The name is resolved when the rule runs, so a reference to a folder that doesn't exist is skipped. + Available common flags: - `--name` diff --git a/go.mod b/go.mod index 2687a2a..b313b6d 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/fatih/color v1.18.0 github.com/gdamore/tcell/v2 v2.13.4 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/ncruces/go-sqlite3 v0.30.4 github.com/rivo/tview v0.42.0 github.com/slack-go/slack v0.23.1 @@ -47,7 +48,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/gdamore/encoding v1.0.1 // indirect github.com/godbus/dbus/v5 v5.2.1 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lucasb-eyer/go-colorful v1.4.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect diff --git a/internal/adapters/nylas/calendars_events.go b/internal/adapters/nylas/calendars_events.go index 5b40b0e..e0891cb 100644 --- a/internal/adapters/nylas/calendars_events.go +++ b/internal/adapters/nylas/calendars_events.go @@ -46,6 +46,8 @@ func (c *HTTPClient) GetEventsWithCursor(ctx context.Context, grantID, calendarI Add("page_token", params.PageToken). AddInt64("start", params.Start). AddInt64("end", params.End). + AddInt64("updated_after", params.UpdatedAfter). + AddInt64("updated_before", params.UpdatedBefore). Add("title", params.Title). Add("location", params.Location). AddBool("show_cancelled", params.ShowCancelled). diff --git a/internal/adapters/nylas/calendars_events_test.go b/internal/adapters/nylas/calendars_events_test.go index ea0fb69..a0a4298 100644 --- a/internal/adapters/nylas/calendars_events_test.go +++ b/internal/adapters/nylas/calendars_events_test.go @@ -180,6 +180,14 @@ func TestHTTPClient_GetEventsWithCursor(t *testing.T) { }, wantQueryKeys: []string{"page_token"}, }, + { + name: "includes updated_after filter", + params: &domain.EventQueryParams{ + Limit: 10, + UpdatedAfter: 1710000000, + }, + wantQueryKeys: []string{"updated_after"}, + }, { // ical_uid is the bridge between an emailed invite and a Nylas // event ID; the RSVP handler relies on the upstream filter so diff --git a/internal/adapters/nylas/contacts.go b/internal/adapters/nylas/contacts.go index f5ce062..a416d13 100644 --- a/internal/adapters/nylas/contacts.go +++ b/internal/adapters/nylas/contacts.go @@ -27,6 +27,7 @@ type contactResponse struct { Notes string `json:"notes"` PictureURL string `json:"picture_url"` Picture string `json:"picture"` + UpdatedAt int64 `json:"updated_at"` Emails []domain.ContactEmail `json:"emails"` PhoneNumbers []domain.ContactPhone `json:"phone_numbers"` WebPages []domain.ContactWebPage `json:"web_pages"` @@ -248,6 +249,7 @@ func convertContact(c contactResponse) domain.Contact { Notes: c.Notes, PictureURL: c.PictureURL, Picture: c.Picture, + UpdatedAt: c.UpdatedAt, Emails: c.Emails, PhoneNumbers: c.PhoneNumbers, WebPages: c.WebPages, diff --git a/internal/adapters/nylas/contacts_internal_test.go b/internal/adapters/nylas/contacts_internal_test.go new file mode 100644 index 0000000..74a6477 --- /dev/null +++ b/internal/adapters/nylas/contacts_internal_test.go @@ -0,0 +1,17 @@ +//go:build !integration +// +build !integration + +package nylas + +import "testing" + +func TestConvertContactIncludesUpdatedAt(t *testing.T) { + contact := convertContact(contactResponse{ + ID: "contact-1", + UpdatedAt: 1700000000, + }) + + if contact.UpdatedAt != 1700000000 { + t.Fatalf("UpdatedAt = %d, want 1700000000", contact.UpdatedAt) + } +} diff --git a/internal/adapters/nylas/demo_threads.go b/internal/adapters/nylas/demo_threads.go index a8b8957..ef00b03 100644 --- a/internal/adapters/nylas/demo_threads.go +++ b/internal/adapters/nylas/demo_threads.go @@ -12,6 +12,13 @@ func (d *DemoClient) GetThreads(ctx context.Context, grantID string, params *dom return d.getDemoThreads(), nil } +// GetThreadsWithCursor returns demo threads with pagination. +func (d *DemoClient) GetThreadsWithCursor(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + return &domain.ThreadListResponse{ + Data: d.getDemoThreads(), + }, nil +} + func (d *DemoClient) getDemoThreads() []domain.Thread { now := time.Now() return []domain.Thread{ diff --git a/internal/adapters/nylas/mock_client.go b/internal/adapters/nylas/mock_client.go index 94839e8..2003dfa 100644 --- a/internal/adapters/nylas/mock_client.go +++ b/internal/adapters/nylas/mock_client.go @@ -111,6 +111,7 @@ type MockClient struct { UpdateMessageFunc func(ctx context.Context, grantID, messageID string, req *domain.UpdateMessageRequest) (*domain.Message, error) DeleteMessageFunc func(ctx context.Context, grantID, messageID string) error GetThreadsFunc func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) ([]domain.Thread, error) + GetThreadsWithCursorFunc func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) GetThreadFunc func(ctx context.Context, grantID, threadID string) (*domain.Thread, error) UpdateThreadFunc func(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) DeleteThreadFunc func(ctx context.Context, grantID, threadID string) error diff --git a/internal/adapters/nylas/mock_threads.go b/internal/adapters/nylas/mock_threads.go index 62a948b..2f3a0db 100644 --- a/internal/adapters/nylas/mock_threads.go +++ b/internal/adapters/nylas/mock_threads.go @@ -15,6 +15,20 @@ func (m *MockClient) GetThreads(ctx context.Context, grantID string, params *dom return []domain.Thread{}, nil } +// GetThreadsWithCursor retrieves threads with pagination cursor support. +func (m *MockClient) GetThreadsWithCursor(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + m.GetThreadsCalled = true + m.LastGrantID = grantID + if m.GetThreadsWithCursorFunc != nil { + return m.GetThreadsWithCursorFunc(ctx, grantID, params) + } + if m.GetThreadsFunc != nil { + threads, err := m.GetThreadsFunc(ctx, grantID, params) + return &domain.ThreadListResponse{Data: threads}, err + } + return &domain.ThreadListResponse{Data: []domain.Thread{}}, nil +} + // GetThread retrieves a single thread. func (m *MockClient) GetThread(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { m.GetThreadCalled = true diff --git a/internal/adapters/nylas/threads.go b/internal/adapters/nylas/threads.go index 7dee809..cde95b0 100644 --- a/internal/adapters/nylas/threads.go +++ b/internal/adapters/nylas/threads.go @@ -34,6 +34,18 @@ type threadResponse struct { // GetThreads retrieves threads with query parameters. func (c *HTTPClient) GetThreads(ctx context.Context, grantID string, params *domain.ThreadQueryParams) ([]domain.Thread, error) { + resp, err := c.GetThreadsWithCursor(ctx, grantID, params) + if err != nil { + return nil, err + } + return resp.Data, nil +} + +// GetThreadsWithCursor retrieves threads with pagination cursor support. +func (c *HTTPClient) GetThreadsWithCursor(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if err := validateRequired("grant ID", grantID); err != nil { + return nil, err + } if params == nil { params = &domain.ThreadQueryParams{Limit: 10} } @@ -45,23 +57,34 @@ func (c *HTTPClient) GetThreads(ctx context.Context, grantID string, params *dom queryURL := NewQueryBuilder(). AddInt("limit", params.Limit). AddInt("offset", params.Offset). + Add("page_token", params.PageToken). Add("subject", params.Subject). Add("from", params.From). Add("to", params.To). AddBoolPtr("unread", params.Unread). AddBoolPtr("starred", params.Starred). + AddInt64("latest_message_before", params.LatestMsgBefore). + AddInt64("latest_message_after", params.LatestMsgAfter). Add("q", params.SearchQuery). AddSlice("in", params.In). BuildURL(baseURL) var result struct { - Data []threadResponse `json:"data"` + Data []threadResponse `json:"data"` + NextCursor string `json:"next_cursor,omitempty"` + RequestID string `json:"request_id,omitempty"` } if err := c.doGet(ctx, queryURL, &result); err != nil { return nil, err } - return convertThreads(result.Data), nil + return &domain.ThreadListResponse{ + Data: convertThreads(result.Data), + Pagination: domain.Pagination{ + NextCursor: result.NextCursor, + HasMore: result.NextCursor != "", + }, + }, nil } // GetThread retrieves a single thread by ID. diff --git a/internal/adapters/nylas/threads_test.go b/internal/adapters/nylas/threads_test.go index 965ef57..19b0bfd 100644 --- a/internal/adapters/nylas/threads_test.go +++ b/internal/adapters/nylas/threads_test.go @@ -213,11 +213,13 @@ func TestHTTPClient_GetThreads_WithFilters(t *testing.T) { // Check query params assert.Equal(t, "20", r.URL.Query().Get("limit")) assert.Equal(t, "5", r.URL.Query().Get("offset")) + assert.Equal(t, "cursor-xyz", r.URL.Query().Get("page_token")) assert.Equal(t, "Important", r.URL.Query().Get("subject")) assert.Equal(t, "alice@example.com", r.URL.Query().Get("from")) assert.Equal(t, "bob@example.com", r.URL.Query().Get("to")) assert.Equal(t, "true", r.URL.Query().Get("unread")) assert.Equal(t, "false", r.URL.Query().Get("starred")) + assert.Equal(t, "1700000000", r.URL.Query().Get("latest_message_after")) assert.Equal(t, "project X", r.URL.Query().Get("q")) response := map[string]any{ @@ -237,14 +239,16 @@ func TestHTTPClient_GetThreads_WithFilters(t *testing.T) { unread := true starred := false params := &domain.ThreadQueryParams{ - Limit: 20, - Offset: 5, - Subject: "Important", - From: "alice@example.com", - To: "bob@example.com", - Unread: &unread, - Starred: &starred, - SearchQuery: "project X", + Limit: 20, + Offset: 5, + PageToken: "cursor-xyz", + Subject: "Important", + From: "alice@example.com", + To: "bob@example.com", + Unread: &unread, + Starred: &starred, + LatestMsgAfter: 1700000000, + SearchQuery: "project X", } threads, err := client.GetThreads(ctx, "grant-filter", params) @@ -252,6 +256,46 @@ func TestHTTPClient_GetThreads_WithFilters(t *testing.T) { assert.Len(t, threads, 0) } +func TestHTTPClient_GetThreadsWithCursor(t *testing.T) { + now := time.Now().Unix() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v3/grants/grant-123/threads", r.URL.Path) + assert.Equal(t, "cursor-1", r.URL.Query().Get("page_token")) + + response := map[string]any{ + "data": []map[string]any{ + { + "id": "thread-1", + "grant_id": "grant-123", + "subject": "First", + "earliest_message_date": now, + "latest_message_received_date": now, + "latest_message_sent_date": now, + }, + }, + "next_cursor": "cursor-2", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client := NewHTTPClient() + client.SetCredentials("client-id", "secret", "api-key") + client.SetBaseURL(server.URL) + + result, err := client.GetThreadsWithCursor(context.Background(), "grant-123", &domain.ThreadQueryParams{ + Limit: 1, + PageToken: "cursor-1", + }) + + require.NoError(t, err) + require.Len(t, result.Data, 1) + assert.Equal(t, "thread-1", result.Data[0].ID) + assert.Equal(t, "cursor-2", result.Pagination.NextCursor) + assert.True(t, result.Pagination.HasMore) +} + func TestHTTPClient_GetThread(t *testing.T) { now := time.Now().Unix() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/adapters/rpcserver/auth.go b/internal/adapters/rpcserver/auth.go new file mode 100644 index 0000000..5fce032 --- /dev/null +++ b/internal/adapters/rpcserver/auth.go @@ -0,0 +1,93 @@ +package rpcserver + +import ( + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "net" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +const ( + // KeyRPCSessionToken is the SecretStore key for the brokered token. + KeyRPCSessionToken = "rpc_session_token" + // EnvWSToken overrides the stored token for headless or file-backed setups. + EnvWSToken = "NYLAS_WS_TOKEN" +) + +// GenerateToken returns a cryptographically random URL-safe token. +func GenerateToken() (string, error) { + token := make([]byte, 32) + if _, err := rand.Read(token); err != nil { + return "", fmt.Errorf("generate rpc session token: %w", err) + } + return base64.RawURLEncoding.EncodeToString(token), nil +} + +// ResolveToken returns the session token from env, storage, or a newly persisted token. +func ResolveToken(store ports.SecretStore, getenv func(string) string) (string, error) { + if token := getenv(EnvWSToken); token != "" { + return token, nil + } + + token, err := store.Get(KeyRPCSessionToken) + if err != nil && !errors.Is(err, domain.ErrSecretNotFound) { + return "", fmt.Errorf("get rpc session token: %w", err) + } + if token != "" { + return token, nil + } + + token, err = GenerateToken() + if err != nil { + return "", err + } + if err := store.Set(KeyRPCSessionToken, token); err != nil { + return "", fmt.Errorf("set rpc session token: %w", err) + } + return token, nil +} + +// ValidateToken does a constant-time comparison. Empty tokens are rejected. +// Both tokens are hashed to a fixed-length digest first so the comparison does +// not leak the token length via timing (ConstantTimeCompare returns early when +// the inputs differ in length). +func ValidateToken(expected, provided string) bool { + if expected == "" || provided == "" { + return false + } + he := sha256.Sum256([]byte(expected)) + hp := sha256.Sum256([]byte(provided)) + return subtle.ConstantTimeCompare(he[:], hp[:]) == 1 +} + +// ValidateOrigin returns true if origin is allowed. Empty origin is allowed for non-browser clients; the token is the gate. +func ValidateOrigin(origin string, allowed []string) bool { + if origin == "" { + return true + } + for _, candidate := range allowed { + if origin == candidate { + return true + } + } + return false +} + +// IsLoopback reports whether a bind address's host is loopback. +func IsLoopback(addr string) (bool, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false, fmt.Errorf("parse bind address: %w", err) + } + if host == "localhost" { + return true, nil + } + ip := net.ParseIP(host) + return ip != nil && ip.IsLoopback(), nil +} diff --git a/internal/adapters/rpcserver/auth_test.go b/internal/adapters/rpcserver/auth_test.go new file mode 100644 index 0000000..3bdbd77 --- /dev/null +++ b/internal/adapters/rpcserver/auth_test.go @@ -0,0 +1,244 @@ +package rpcserver + +import ( + "errors" + "strings" + "testing" + + "github.com/nylas/cli/internal/domain" +) + +type fakeSecretStore struct { + secrets map[string]string + getErr error + setErr error + setKey string + setVal string +} + +func newFakeSecretStore() *fakeSecretStore { + return &fakeSecretStore{secrets: make(map[string]string)} +} + +func (f *fakeSecretStore) Set(key, value string) error { + if f.setErr != nil { + return f.setErr + } + f.setKey = key + f.setVal = value + f.secrets[key] = value + return nil +} + +func (f *fakeSecretStore) Get(key string) (string, error) { + if f.getErr != nil { + if errors.Is(f.getErr, domain.ErrSecretNotFound) && f.secrets[key] != "" { + return f.secrets[key], nil + } + return "", f.getErr + } + return f.secrets[key], nil +} + +func (f *fakeSecretStore) Delete(key string) error { + delete(f.secrets, key) + return nil +} + +func (f *fakeSecretStore) IsAvailable() bool { return true } + +func (f *fakeSecretStore) Name() string { return "fake" } + +func TestGenerateToken(t *testing.T) { + first, err := GenerateToken() + if err != nil { + t.Fatalf("GenerateToken() error = %v", err) + } + second, err := GenerateToken() + if err != nil { + t.Fatalf("GenerateToken() second error = %v", err) + } + + if first == "" { + t.Fatal("GenerateToken() returned empty token") + } + if first == second { + t.Fatal("GenerateToken() returned the same token twice") + } + if len(first) != 43 { + t.Fatalf("GenerateToken() length = %d, want 43", len(first)) + } + if strings.ContainsAny(first, "+/=") { + t.Fatalf("GenerateToken() = %q, want URL-safe token without padding", first) + } +} + +func TestResolveToken(t *testing.T) { + storeErr := errors.New("store unavailable") + + tests := []struct { + name string + store *fakeSecretStore + getenv func(string) string + want string + wantErr error + wantSet bool + afterGet bool + }{ + { + name: "env token wins", + store: &fakeSecretStore{getErr: storeErr}, + getenv: func(key string) string { + if key == EnvWSToken { + return "env-token" + } + return "" + }, + want: "env-token", + }, + { + name: "store token returned", + store: &fakeSecretStore{secrets: map[string]string{KeyRPCSessionToken: "stored-token"}}, + getenv: func(string) string { + return "" + }, + want: "stored-token", + }, + { + name: "empty store generates and persists", + store: newFakeSecretStore(), + getenv: func(string) string { return "" }, + wantSet: true, + afterGet: true, + }, + { + name: "missing store token generates and persists", + store: &fakeSecretStore{secrets: make(map[string]string), getErr: domain.ErrSecretNotFound}, + getenv: func(string) string { return "" }, + wantSet: true, + afterGet: true, + }, + { + name: "store error propagates", + store: &fakeSecretStore{getErr: storeErr}, + getenv: func(string) string { + return "" + }, + wantErr: storeErr, + }, + { + name: "store set error propagates", + store: &fakeSecretStore{secrets: make(map[string]string), setErr: storeErr}, + getenv: func(string) string { return "" }, + wantErr: storeErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ResolveToken(tt.store, tt.getenv) + if tt.wantErr != nil { + if !errors.Is(err, tt.wantErr) { + t.Fatalf("ResolveToken() error = %v, want wrapping %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Fatalf("ResolveToken() error = %v", err) + } + if tt.want != "" && got != tt.want { + t.Fatalf("ResolveToken() = %q, want %q", got, tt.want) + } + if tt.wantSet && (tt.store.setKey != KeyRPCSessionToken || tt.store.setVal == "") { + t.Fatalf("ResolveToken() persisted key/value = %q/%q, want key %q and non-empty value", tt.store.setKey, tt.store.setVal, KeyRPCSessionToken) + } + if tt.afterGet { + stored, err := tt.store.Get(KeyRPCSessionToken) + if err != nil { + t.Fatalf("store.Get() error = %v", err) + } + if stored != got { + t.Fatalf("stored token = %q, want generated token %q", stored, got) + } + } + }) + } +} + +func TestValidateToken(t *testing.T) { + tests := []struct { + name string + expected string + provided string + want bool + }{ + {name: "match", expected: "abc123", provided: "abc123", want: true}, + {name: "equal length mismatch exercises subtle compare path", expected: "abc123", provided: "abc124", want: false}, + {name: "different length mismatch", expected: "abc123", provided: "abc1234", want: false}, + {name: "empty expected", expected: "", provided: "abc123", want: false}, + {name: "empty provided", expected: "abc123", provided: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ValidateToken(tt.expected, tt.provided); got != tt.want { + t.Fatalf("ValidateToken() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestValidateOrigin(t *testing.T) { + allowed := []string{"http://localhost:3000", "https://app.example.com"} + tests := []struct { + name string + origin string + want bool + }{ + {name: "allow-list hit", origin: "http://localhost:3000", want: true}, + {name: "allow-list miss", origin: "http://localhost:3001", want: false}, + {name: "empty origin allowed", origin: "", want: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ValidateOrigin(tt.origin, allowed); got != tt.want { + t.Fatalf("ValidateOrigin() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsLoopback(t *testing.T) { + tests := []struct { + name string + addr string + want bool + wantErr bool + }{ + {name: "ipv4 loopback", addr: "127.0.0.1:8080", want: true}, + {name: "ipv6 loopback", addr: "[::1]:8080", want: true}, + {name: "localhost", addr: "localhost:0", want: true}, + {name: "unspecified ipv4", addr: "0.0.0.0:8080", want: false}, + {name: "public ip", addr: "8.8.8.8:8080", want: false}, + {name: "garbage", addr: "garbage", wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := IsLoopback(tt.addr) + if tt.wantErr { + if err == nil { + t.Fatal("IsLoopback() error = nil, want error") + } + return + } + if err != nil { + t.Fatalf("IsLoopback() error = %v", err) + } + if got != tt.want { + t.Fatalf("IsLoopback() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/adapters/rpcserver/handlers.go b/internal/adapters/rpcserver/handlers.go new file mode 100644 index 0000000..f192cab --- /dev/null +++ b/internal/adapters/rpcserver/handlers.go @@ -0,0 +1,27 @@ +package rpcserver + +import "encoding/json" + +func resolveGrant(grantID, defaultGrant string) (string, error) { + if grantID != "" { + return grantID, nil + } + if defaultGrant != "" { + return defaultGrant, nil + } + return "", NewRPCError(InvalidParams, "grant_id required", nil) +} + +func decodeParams(params json.RawMessage, v any) error { + if len(params) == 0 { + return nil + } + if err := json.Unmarshal(params, v); err != nil { + return NewRPCError(InvalidParams, "invalid params", err.Error()) + } + return nil +} + +type deletedResult struct { + Deleted bool `json:"deleted"` +} diff --git a/internal/adapters/rpcserver/handlers_admin.go b/internal/adapters/rpcserver/handlers_admin.go new file mode 100644 index 0000000..4981f1b --- /dev/null +++ b/internal/adapters/rpcserver/handlers_admin.go @@ -0,0 +1,523 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type appListResult struct { + Applications []domain.Application `json:"applications"` +} + +type appGetParams struct { + AppID string `json:"app_id"` +} + +type appCreateParams struct { + domain.CreateApplicationRequest +} + +type appUpdateParams struct { + AppID string `json:"app_id"` + domain.UpdateApplicationRequest +} + +type appDeleteParams struct { + AppID string `json:"app_id"` +} + +type callbackURIListResult struct { + CallbackURIs []domain.CallbackURI `json:"callback_uris"` +} + +type callbackURIGetParams struct { + URIID string `json:"uri_id"` +} + +type callbackURICreateParams struct { + domain.CreateCallbackURIRequest +} + +type callbackURIUpdateParams struct { + URIID string `json:"uri_id"` + domain.UpdateCallbackURIRequest +} + +type callbackURIDeleteParams struct { + URIID string `json:"uri_id"` +} + +type connectorListResult struct { + Connectors []domain.Connector `json:"connectors"` +} + +type connectorGetParams struct { + ConnectorID string `json:"connector_id"` +} + +type connectorCreateParams struct { + domain.CreateConnectorRequest +} + +type connectorUpdateParams struct { + ConnectorID string `json:"connector_id"` + domain.UpdateConnectorRequest +} + +type connectorDeleteParams struct { + ConnectorID string `json:"connector_id"` +} + +type credentialListParams struct { + ConnectorID string `json:"connector_id"` +} + +type credentialListResult struct { + Credentials []domain.ConnectorCredential `json:"credentials"` +} + +type credentialGetParams struct { + CredentialID string `json:"credential_id"` +} + +type credentialCreateParams struct { + ConnectorID string `json:"connector_id"` + domain.CreateCredentialRequest +} + +type credentialUpdateParams struct { + CredentialID string `json:"credential_id"` + domain.UpdateCredentialRequest +} + +type credentialDeleteParams struct { + CredentialID string `json:"credential_id"` +} + +type workspaceListResult struct { + Workspaces []domain.Workspace `json:"workspaces"` +} + +type workspaceGetParams struct { + WorkspaceID string `json:"workspace_id"` +} + +type workspaceCreateParams struct { + domain.CreateWorkspaceRequest +} + +type workspaceUpdateParams struct { + WorkspaceID string `json:"workspace_id"` + domain.UpdateWorkspaceRequest +} + +type workspaceDeleteParams struct { + WorkspaceID string `json:"workspace_id"` +} + +type workspaceAssignParams struct { + WorkspaceID string `json:"workspace_id"` + domain.WorkspaceAssignRequest +} + +type grantsListAllParams struct { + domain.GrantsQueryParams +} + +type grantsListAllResult struct { + Grants []domain.Grant `json:"grants"` +} + +func RegisterAdminHandlers(d *Dispatcher, client ports.AdminClient) { + d.Register("admin.app.list", func(ctx context.Context, params json.RawMessage) (any, error) { + apps, err := client.ListApplications(ctx) + if err != nil { + return nil, fmt.Errorf("admin.app.list: %w", err) + } + return appListResult{Applications: apps}, nil + }) + + d.Register("admin.app.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p appGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.AppID == "" { + return nil, NewRPCError(InvalidParams, "app_id required", nil) + } + + app, err := client.GetApplication(ctx, p.AppID) + if err != nil { + return nil, fmt.Errorf("admin.app.get: %w", err) + } + return app, nil + }) + + d.Register("admin.app.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p appCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + app, err := client.CreateApplication(ctx, &p.CreateApplicationRequest) + if err != nil { + return nil, fmt.Errorf("admin.app.create: %w", err) + } + return app, nil + }) + + d.Register("admin.app.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p appUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.AppID == "" { + return nil, NewRPCError(InvalidParams, "app_id required", nil) + } + + app, err := client.UpdateApplication(ctx, p.AppID, &p.UpdateApplicationRequest) + if err != nil { + return nil, fmt.Errorf("admin.app.update: %w", err) + } + return app, nil + }) + + d.Register("admin.app.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p appDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.AppID == "" { + return nil, NewRPCError(InvalidParams, "app_id required", nil) + } + + if err := client.DeleteApplication(ctx, p.AppID); err != nil { + return nil, fmt.Errorf("admin.app.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("admin.callbackUri.list", func(ctx context.Context, params json.RawMessage) (any, error) { + uris, err := client.ListCallbackURIs(ctx) + if err != nil { + return nil, fmt.Errorf("admin.callbackUri.list: %w", err) + } + return callbackURIListResult{CallbackURIs: uris}, nil + }) + + d.Register("admin.callbackUri.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p callbackURIGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.URIID == "" { + return nil, NewRPCError(InvalidParams, "uri_id required", nil) + } + + uri, err := client.GetCallbackURI(ctx, p.URIID) + if err != nil { + return nil, fmt.Errorf("admin.callbackUri.get: %w", err) + } + return uri, nil + }) + + d.Register("admin.callbackUri.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p callbackURICreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + uri, err := client.CreateCallbackURI(ctx, &p.CreateCallbackURIRequest) + if err != nil { + return nil, fmt.Errorf("admin.callbackUri.create: %w", err) + } + return uri, nil + }) + + d.Register("admin.callbackUri.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p callbackURIUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.URIID == "" { + return nil, NewRPCError(InvalidParams, "uri_id required", nil) + } + + uri, err := client.UpdateCallbackURI(ctx, p.URIID, &p.UpdateCallbackURIRequest) + if err != nil { + return nil, fmt.Errorf("admin.callbackUri.update: %w", err) + } + return uri, nil + }) + + d.Register("admin.callbackUri.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p callbackURIDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.URIID == "" { + return nil, NewRPCError(InvalidParams, "uri_id required", nil) + } + + if err := client.DeleteCallbackURI(ctx, p.URIID); err != nil { + return nil, fmt.Errorf("admin.callbackUri.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("admin.connector.list", func(ctx context.Context, params json.RawMessage) (any, error) { + connectors, err := client.ListConnectors(ctx) + if err != nil { + return nil, fmt.Errorf("admin.connector.list: %w", err) + } + return connectorListResult{Connectors: connectors}, nil + }) + + d.Register("admin.connector.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p connectorGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConnectorID == "" { + return nil, NewRPCError(InvalidParams, "connector_id required", nil) + } + + connector, err := client.GetConnector(ctx, p.ConnectorID) + if err != nil { + return nil, fmt.Errorf("admin.connector.get: %w", err) + } + return connector, nil + }) + + d.Register("admin.connector.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p connectorCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + connector, err := client.CreateConnector(ctx, &p.CreateConnectorRequest) + if err != nil { + return nil, fmt.Errorf("admin.connector.create: %w", err) + } + return connector, nil + }) + + d.Register("admin.connector.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p connectorUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConnectorID == "" { + return nil, NewRPCError(InvalidParams, "connector_id required", nil) + } + + connector, err := client.UpdateConnector(ctx, p.ConnectorID, &p.UpdateConnectorRequest) + if err != nil { + return nil, fmt.Errorf("admin.connector.update: %w", err) + } + return connector, nil + }) + + d.Register("admin.connector.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p connectorDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConnectorID == "" { + return nil, NewRPCError(InvalidParams, "connector_id required", nil) + } + + if err := client.DeleteConnector(ctx, p.ConnectorID); err != nil { + return nil, fmt.Errorf("admin.connector.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("admin.credential.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p credentialListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConnectorID == "" { + return nil, NewRPCError(InvalidParams, "connector_id required", nil) + } + + credentials, err := client.ListCredentials(ctx, p.ConnectorID) + if err != nil { + return nil, fmt.Errorf("admin.credential.list: %w", err) + } + return credentialListResult{Credentials: credentials}, nil + }) + + d.Register("admin.credential.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p credentialGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CredentialID == "" { + return nil, NewRPCError(InvalidParams, "credential_id required", nil) + } + + credential, err := client.GetCredential(ctx, p.CredentialID) + if err != nil { + return nil, fmt.Errorf("admin.credential.get: %w", err) + } + return credential, nil + }) + + d.Register("admin.credential.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p credentialCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConnectorID == "" { + return nil, NewRPCError(InvalidParams, "connector_id required", nil) + } + + credential, err := client.CreateCredential(ctx, p.ConnectorID, &p.CreateCredentialRequest) + if err != nil { + return nil, fmt.Errorf("admin.credential.create: %w", err) + } + return credential, nil + }) + + d.Register("admin.credential.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p credentialUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CredentialID == "" { + return nil, NewRPCError(InvalidParams, "credential_id required", nil) + } + + credential, err := client.UpdateCredential(ctx, p.CredentialID, &p.UpdateCredentialRequest) + if err != nil { + return nil, fmt.Errorf("admin.credential.update: %w", err) + } + return credential, nil + }) + + d.Register("admin.credential.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p credentialDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CredentialID == "" { + return nil, NewRPCError(InvalidParams, "credential_id required", nil) + } + + if err := client.DeleteCredential(ctx, p.CredentialID); err != nil { + return nil, fmt.Errorf("admin.credential.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("workspace.list", func(ctx context.Context, params json.RawMessage) (any, error) { + workspaces, err := client.ListWorkspaces(ctx) + if err != nil { + return nil, fmt.Errorf("workspace.list: %w", err) + } + return workspaceListResult{Workspaces: workspaces}, nil + }) + + d.Register("workspace.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workspaceGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkspaceID == "" { + return nil, NewRPCError(InvalidParams, "workspace_id required", nil) + } + + workspace, err := client.GetWorkspace(ctx, p.WorkspaceID) + if err != nil { + return nil, fmt.Errorf("workspace.get: %w", err) + } + return workspace, nil + }) + + d.Register("workspace.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workspaceCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + workspace, err := client.CreateWorkspace(ctx, &p.CreateWorkspaceRequest) + if err != nil { + return nil, fmt.Errorf("workspace.create: %w", err) + } + return workspace, nil + }) + + d.Register("workspace.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workspaceUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkspaceID == "" { + return nil, NewRPCError(InvalidParams, "workspace_id required", nil) + } + + workspace, err := client.UpdateWorkspace(ctx, p.WorkspaceID, &p.UpdateWorkspaceRequest) + if err != nil { + return nil, fmt.Errorf("workspace.update: %w", err) + } + return workspace, nil + }) + + d.Register("workspace.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workspaceDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkspaceID == "" { + return nil, NewRPCError(InvalidParams, "workspace_id required", nil) + } + + if err := client.DeleteWorkspace(ctx, p.WorkspaceID); err != nil { + return nil, fmt.Errorf("workspace.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("workspace.assignGrants", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workspaceAssignParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkspaceID == "" { + return nil, NewRPCError(InvalidParams, "workspace_id required", nil) + } + + result, err := client.AssignWorkspaceGrants(ctx, p.WorkspaceID, &p.WorkspaceAssignRequest) + if err != nil { + return nil, fmt.Errorf("workspace.assignGrants: %w", err) + } + return result, nil + }) + + d.Register("admin.grants.listAll", func(ctx context.Context, params json.RawMessage) (any, error) { + var p grantsListAllParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grants, err := client.ListAllGrants(ctx, &p.GrantsQueryParams) + if err != nil { + return nil, fmt.Errorf("admin.grants.listAll: %w", err) + } + return grantsListAllResult{Grants: grants}, nil + }) + + d.Register("admin.grants.stats", func(ctx context.Context, params json.RawMessage) (any, error) { + stats, err := client.GetGrantStats(ctx) + if err != nil { + return nil, fmt.Errorf("admin.grants.stats: %w", err) + } + return stats, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_admin_test.go b/internal/adapters/rpcserver/handlers_admin_test.go new file mode 100644 index 0000000..bc937aa --- /dev/null +++ b/internal/adapters/rpcserver/handlers_admin_test.go @@ -0,0 +1,461 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeAdminClient struct { + ports.AdminClient + + listApplications func(context.Context) ([]domain.Application, error) + createCallbackURI func(context.Context, *domain.CreateCallbackURIRequest) (*domain.CallbackURI, error) + updateConnector func(context.Context, string, *domain.UpdateConnectorRequest) (*domain.Connector, error) + listCredentials func(context.Context, string) ([]domain.ConnectorCredential, error) + getCredential func(context.Context, string) (*domain.ConnectorCredential, error) + getWorkspace func(context.Context, string) (*domain.Workspace, error) + assignWorkspaceGrants func(context.Context, string, *domain.WorkspaceAssignRequest) (*domain.WorkspaceAssignResult, error) + listAllGrants func(context.Context, *domain.GrantsQueryParams) ([]domain.Grant, error) + getGrantStats func(context.Context) (*domain.GrantStats, error) +} + +func (f *fakeAdminClient) ListApplications(ctx context.Context) ([]domain.Application, error) { + if f.listApplications == nil { + return nil, errors.New("unexpected ListApplications") + } + return f.listApplications(ctx) +} + +func (f *fakeAdminClient) CreateCallbackURI(ctx context.Context, req *domain.CreateCallbackURIRequest) (*domain.CallbackURI, error) { + if f.createCallbackURI == nil { + return nil, errors.New("unexpected CreateCallbackURI") + } + return f.createCallbackURI(ctx, req) +} + +func (f *fakeAdminClient) UpdateConnector(ctx context.Context, connectorID string, req *domain.UpdateConnectorRequest) (*domain.Connector, error) { + if f.updateConnector == nil { + return nil, errors.New("unexpected UpdateConnector") + } + return f.updateConnector(ctx, connectorID, req) +} + +func (f *fakeAdminClient) ListCredentials(ctx context.Context, connectorID string) ([]domain.ConnectorCredential, error) { + if f.listCredentials == nil { + return nil, errors.New("unexpected ListCredentials") + } + return f.listCredentials(ctx, connectorID) +} + +func (f *fakeAdminClient) GetCredential(ctx context.Context, credentialID string) (*domain.ConnectorCredential, error) { + if f.getCredential == nil { + return nil, errors.New("unexpected GetCredential") + } + return f.getCredential(ctx, credentialID) +} + +func (f *fakeAdminClient) GetWorkspace(ctx context.Context, workspaceID string) (*domain.Workspace, error) { + if f.getWorkspace == nil { + return nil, errors.New("unexpected GetWorkspace") + } + return f.getWorkspace(ctx, workspaceID) +} + +func (f *fakeAdminClient) AssignWorkspaceGrants(ctx context.Context, workspaceID string, req *domain.WorkspaceAssignRequest) (*domain.WorkspaceAssignResult, error) { + if f.assignWorkspaceGrants == nil { + return nil, errors.New("unexpected AssignWorkspaceGrants") + } + return f.assignWorkspaceGrants(ctx, workspaceID, req) +} + +func (f *fakeAdminClient) ListAllGrants(ctx context.Context, params *domain.GrantsQueryParams) ([]domain.Grant, error) { + if f.listAllGrants == nil { + return nil, errors.New("unexpected ListAllGrants") + } + return f.listAllGrants(ctx, params) +} + +func (f *fakeAdminClient) GetGrantStats(ctx context.Context) (*domain.GrantStats, error) { + if f.getGrantStats == nil { + return nil, errors.New("unexpected GetGrantStats") + } + return f.getGrantStats(ctx) +} + +func TestRegisterAdminHandlers_RegistersAllMethods(t *testing.T) { + d := NewDispatcher() + RegisterAdminHandlers(d, &fakeAdminClient{}) + + methods := []string{ + "admin.app.list", + "admin.app.get", + "admin.app.create", + "admin.app.update", + "admin.app.delete", + "admin.callbackUri.list", + "admin.callbackUri.get", + "admin.callbackUri.create", + "admin.callbackUri.update", + "admin.callbackUri.delete", + "admin.connector.list", + "admin.connector.get", + "admin.connector.create", + "admin.connector.update", + "admin.connector.delete", + "admin.credential.list", + "admin.credential.get", + "admin.credential.create", + "admin.credential.update", + "admin.credential.delete", + "workspace.list", + "workspace.get", + "workspace.create", + "workspace.update", + "workspace.delete", + "workspace.assignGrants", + "admin.grants.listAll", + "admin.grants.stats", + } + + if len(d.handlers) != len(methods) { + t.Fatalf("registered handlers = %d, want %d", len(d.handlers), len(methods)) + } + for _, method := range methods { + if d.handlers[method] == nil { + t.Fatalf("handler %q not registered", method) + } + } +} + +func TestRegisterAdminHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + client *fakeAdminClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "admin.app.list returns applications", + method: "admin.app.list", + params: `{}`, + client: &fakeAdminClient{ + listApplications: func(ctx context.Context) ([]domain.Application, error) { + return []domain.Application{{ID: "app-1"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Applications []domain.Application `json:"applications"` + } + unmarshalResult(t, resp, &result) + if len(result.Applications) != 1 || result.Applications[0].ID != "app-1" { + t.Fatalf("applications = %#v, want app-1", result.Applications) + } + }, + }, + { + name: "admin.app.get missing app_id returns invalid params", + method: "admin.app.get", + params: `{}`, + client: &fakeAdminClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "admin.app.list client error maps to internal error", + method: "admin.app.list", + params: `{}`, + client: &fakeAdminClient{ + listApplications: func(ctx context.Context) ([]domain.Application, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "admin.callbackUri.create embeds request fields", + method: "admin.callbackUri.create", + params: `{"url":"http://localhost/callback","platform":"web"}`, + client: &fakeAdminClient{ + createCallbackURI: func(ctx context.Context, req *domain.CreateCallbackURIRequest) (*domain.CallbackURI, error) { + if req.URL != "http://localhost/callback" || req.Platform != "web" { + t.Fatalf("request = %#v, want callback URL and web platform", req) + } + return &domain.CallbackURI{ID: "uri-1", URL: req.URL, Platform: req.Platform}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var uri domain.CallbackURI + unmarshalResult(t, resp, &uri) + if uri.ID != "uri-1" { + t.Fatalf("callback URI = %#v, want uri-1", uri) + } + }, + }, + { + name: "admin.callbackUri.get missing uri_id returns invalid params", + method: "admin.callbackUri.get", + params: `{}`, + client: &fakeAdminClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "admin.callbackUri.create client error maps to internal error", + method: "admin.callbackUri.create", + params: `{"url":"http://localhost/callback","platform":"web"}`, + client: &fakeAdminClient{ + createCallbackURI: func(ctx context.Context, req *domain.CreateCallbackURIRequest) (*domain.CallbackURI, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "admin.connector.update embeds request fields", + method: "admin.connector.update", + params: `{"connector_id":"connector-1","name":"Google"}`, + client: &fakeAdminClient{ + updateConnector: func(ctx context.Context, connectorID string, req *domain.UpdateConnectorRequest) (*domain.Connector, error) { + if connectorID != "connector-1" { + t.Fatalf("connectorID = %q, want connector-1", connectorID) + } + if req.Name == nil || *req.Name != "Google" { + t.Fatalf("Name = %#v, want Google", req.Name) + } + return &domain.Connector{ID: connectorID, Name: *req.Name, Provider: "google"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var connector domain.Connector + unmarshalResult(t, resp, &connector) + if connector.ID != "connector-1" || connector.Name != "Google" { + t.Fatalf("connector = %#v, want connector-1 Google", connector) + } + }, + }, + { + name: "admin.connector.update missing connector_id returns invalid params", + method: "admin.connector.update", + params: `{"name":"Google"}`, + client: &fakeAdminClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "admin.connector.update client error maps to internal error", + method: "admin.connector.update", + params: `{"connector_id":"connector-1","name":"Google"}`, + client: &fakeAdminClient{ + updateConnector: func(ctx context.Context, connectorID string, req *domain.UpdateConnectorRequest) (*domain.Connector, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "admin.credential.list returns credentials", + method: "admin.credential.list", + params: `{"connector_id":"connector-1"}`, + client: &fakeAdminClient{ + listCredentials: func(ctx context.Context, connectorID string) ([]domain.ConnectorCredential, error) { + if connectorID != "connector-1" { + t.Fatalf("connectorID = %q, want connector-1", connectorID) + } + return []domain.ConnectorCredential{{ID: "credential-1", Name: "OAuth", CredentialType: "oauth"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Credentials []domain.ConnectorCredential `json:"credentials"` + } + unmarshalResult(t, resp, &result) + if len(result.Credentials) != 1 || result.Credentials[0].ID != "credential-1" { + t.Fatalf("credentials = %#v, want credential-1", result.Credentials) + } + }, + }, + { + name: "admin.credential.list missing connector_id returns invalid params", + method: "admin.credential.list", + params: `{}`, + client: &fakeAdminClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "admin.credential.get client error maps to internal error", + method: "admin.credential.get", + params: `{"credential_id":"credential-1"}`, + client: &fakeAdminClient{ + getCredential: func(ctx context.Context, credentialID string) (*domain.ConnectorCredential, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "workspace.assignGrants embeds request fields", + method: "workspace.assignGrants", + params: `{"workspace_id":"workspace-1","assign_grants":["grant-1"],"remove_grants":["grant-2"]}`, + client: &fakeAdminClient{ + assignWorkspaceGrants: func(ctx context.Context, workspaceID string, req *domain.WorkspaceAssignRequest) (*domain.WorkspaceAssignResult, error) { + if workspaceID != "workspace-1" { + t.Fatalf("workspaceID = %q, want workspace-1", workspaceID) + } + if len(req.AssignGrants) != 1 || req.AssignGrants[0] != "grant-1" { + t.Fatalf("AssignGrants = %#v, want grant-1", req.AssignGrants) + } + if len(req.RemoveGrants) != 1 || req.RemoveGrants[0] != "grant-2" { + t.Fatalf("RemoveGrants = %#v, want grant-2", req.RemoveGrants) + } + return &domain.WorkspaceAssignResult{WorkspaceID: workspaceID, GrantsAssigned: req.AssignGrants, GrantsRemoved: req.RemoveGrants}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result domain.WorkspaceAssignResult + unmarshalResult(t, resp, &result) + if result.WorkspaceID != "workspace-1" || len(result.GrantsAssigned) != 1 || result.GrantsAssigned[0] != "grant-1" { + t.Fatalf("assign result = %#v, want workspace-1 grant-1", result) + } + }, + }, + { + name: "workspace.assignGrants missing workspace_id returns invalid params", + method: "workspace.assignGrants", + params: `{"assign_grants":["grant-1"]}`, + client: &fakeAdminClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "workspace.get client error maps to internal error", + method: "workspace.get", + params: `{"workspace_id":"workspace-1"}`, + client: &fakeAdminClient{ + getWorkspace: func(ctx context.Context, workspaceID string) (*domain.Workspace, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "admin.grants.listAll returns grants and forwards params", + method: "admin.grants.listAll", + params: `{"limit":10,"offset":5,"connector_id":"connector-1","status":"valid"}`, + client: &fakeAdminClient{ + listAllGrants: func(ctx context.Context, params *domain.GrantsQueryParams) ([]domain.Grant, error) { + if params.Limit != 10 || params.Offset != 5 || params.ConnectorID != "connector-1" || params.Status != "valid" { + t.Fatalf("params = %#v, want list all query params", params) + } + return []domain.Grant{{ID: "grant-1", Email: "user@example.com", GrantStatus: "valid"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Grants []domain.Grant `json:"grants"` + } + unmarshalResult(t, resp, &result) + if len(result.Grants) != 1 || result.Grants[0].ID != "grant-1" { + t.Fatalf("grants = %#v, want grant-1", result.Grants) + } + }, + }, + { + name: "admin.grants.stats returns stats", + method: "admin.grants.stats", + params: `{}`, + client: &fakeAdminClient{ + getGrantStats: func(ctx context.Context) (*domain.GrantStats, error) { + return &domain.GrantStats{Total: 2, Valid: 1, Invalid: 1}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var stats domain.GrantStats + unmarshalResult(t, resp, &stats) + if stats.Total != 2 || stats.Valid != 1 || stats.Invalid != 1 { + t.Fatalf("stats = %#v, want total 2 valid 1 invalid 1", stats) + } + }, + }, + { + name: "admin.grants.stats client error maps to internal error", + method: "admin.grants.stats", + params: `{}`, + client: &fakeAdminClient{ + getGrantStats: func(ctx context.Context) (*domain.GrantStats, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAdminHandlers(d, tt.client) + + resp := dispatchAdminRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func dispatchAdminRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} diff --git a/internal/adapters/rpcserver/handlers_audit.go b/internal/adapters/rpcserver/handlers_audit.go new file mode 100644 index 0000000..fdb6936 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_audit.go @@ -0,0 +1,154 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type auditListParams struct { + Limit int `json:"limit,omitempty"` +} + +type auditEntriesResult struct { + Entries []domain.AuditEntry `json:"entries"` +} + +type auditStatsResult struct { + FileCount int `json:"file_count"` + TotalSizeBytes int64 `json:"total_size_bytes"` + OldestEntry *domain.AuditEntry `json:"oldest_entry"` +} + +type auditPathResult struct { + Path string `json:"path"` +} + +type auditOKResult struct { + OK bool `json:"ok"` +} + +type auditClearedResult struct { + Cleared bool `json:"cleared"` +} + +func RegisterAuditHandlers(d *Dispatcher, svc ports.AuditStore) { + d.Register("audit.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p auditListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + entries, err := svc.List(ctx, p.Limit) + if err != nil { + return nil, fmt.Errorf("audit.list: %w", err) + } + return auditEntriesResult{Entries: entries}, nil + }) + + d.Register("audit.query", func(ctx context.Context, params json.RawMessage) (any, error) { + var p domain.AuditQueryOptions + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + entries, err := svc.Query(ctx, &p) + if err != nil { + return nil, fmt.Errorf("audit.query: %w", err) + } + return auditEntriesResult{Entries: entries}, nil + }) + + d.Register("audit.summary", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct { + Days int `json:"days,omitempty"` + } + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + summary, err := svc.Summary(ctx, p.Days) + if err != nil { + return nil, fmt.Errorf("audit.summary: %w", err) + } + return summary, nil + }) + + d.Register("audit.stats", func(_ context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + fileCount, totalSizeBytes, oldestEntry, err := svc.Stats() + if err != nil { + return nil, fmt.Errorf("audit.stats: %w", err) + } + return auditStatsResult{ + FileCount: fileCount, + TotalSizeBytes: totalSizeBytes, + OldestEntry: oldestEntry, + }, nil + }) + + d.Register("audit.config.read", func(_ context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + cfg, err := svc.GetConfig() + if err != nil { + return nil, fmt.Errorf("audit.config.read: %w", err) + } + return cfg, nil + }) + + d.Register("audit.config.save", func(_ context.Context, params json.RawMessage) (any, error) { + var p domain.AuditConfig + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + if err := svc.SaveConfig(&p); err != nil { + return nil, fmt.Errorf("audit.config.save: %w", err) + } + return auditOKResult{OK: true}, nil + }) + + d.Register("audit.path", func(_ context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + return auditPathResult{Path: svc.Path()}, nil + }) + + d.Register("audit.clear", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + if err := svc.Clear(ctx); err != nil { + return nil, fmt.Errorf("audit.clear: %w", err) + } + return auditClearedResult{Cleared: true}, nil + }) + + d.Register("audit.cleanup", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + if err := svc.Cleanup(ctx); err != nil { + return nil, fmt.Errorf("audit.cleanup: %w", err) + } + return auditOKResult{OK: true}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_audit_test.go b/internal/adapters/rpcserver/handlers_audit_test.go new file mode 100644 index 0000000..0fc0ee6 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_audit_test.go @@ -0,0 +1,419 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeAuditService struct { + ports.AuditStore + + getConfig func() (*domain.AuditConfig, error) + saveConfig func(*domain.AuditConfig) error + list func(context.Context, int) ([]domain.AuditEntry, error) + query func(context.Context, *domain.AuditQueryOptions) ([]domain.AuditEntry, error) + summary func(context.Context, int) (*domain.AuditSummary, error) + clear func(context.Context) error + path func() string + stats func() (int, int64, *domain.AuditEntry, error) + cleanup func(context.Context) error +} + +func (f *fakeAuditService) GetConfig() (*domain.AuditConfig, error) { + if f.getConfig == nil { + return nil, errors.New("unexpected GetConfig") + } + return f.getConfig() +} + +func (f *fakeAuditService) SaveConfig(cfg *domain.AuditConfig) error { + if f.saveConfig == nil { + return errors.New("unexpected SaveConfig") + } + return f.saveConfig(cfg) +} + +func (f *fakeAuditService) List(ctx context.Context, limit int) ([]domain.AuditEntry, error) { + if f.list == nil { + return nil, errors.New("unexpected List") + } + return f.list(ctx, limit) +} + +func (f *fakeAuditService) Query(ctx context.Context, opts *domain.AuditQueryOptions) ([]domain.AuditEntry, error) { + if f.query == nil { + return nil, errors.New("unexpected Query") + } + return f.query(ctx, opts) +} + +func (f *fakeAuditService) Summary(ctx context.Context, days int) (*domain.AuditSummary, error) { + if f.summary == nil { + return nil, errors.New("unexpected Summary") + } + return f.summary(ctx, days) +} + +func (f *fakeAuditService) Clear(ctx context.Context) error { + if f.clear == nil { + return errors.New("unexpected Clear") + } + return f.clear(ctx) +} + +func (f *fakeAuditService) Path() string { + if f.path == nil { + return "" + } + return f.path() +} + +func (f *fakeAuditService) Stats() (int, int64, *domain.AuditEntry, error) { + if f.stats == nil { + return 0, 0, nil, errors.New("unexpected Stats") + } + return f.stats() +} + +func (f *fakeAuditService) Cleanup(ctx context.Context) error { + if f.cleanup == nil { + return errors.New("unexpected Cleanup") + } + return f.cleanup(ctx) +} + +type auditRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +func TestRegisterAuditHandlers(t *testing.T) { + now := time.Date(2026, 6, 25, 12, 30, 0, 0, time.UTC) + entry := domain.AuditEntry{ + ID: "audit-1", + Timestamp: now, + Command: "email list", + GrantID: "grant-1", + Status: domain.AuditStatusSuccess, + } + + tests := []struct { + name string + method string + params string + svc *fakeAuditService + assert func(*testing.T, auditRPCResponse) + }{ + { + name: "audit.list returns entries", + method: "audit.list", + params: `{"limit":2}`, + svc: &fakeAuditService{ + list: func(ctx context.Context, limit int) ([]domain.AuditEntry, error) { + if limit != 2 { + t.Fatalf("limit = %d, want 2", limit) + } + return []domain.AuditEntry{entry}, nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + Entries []domain.AuditEntry `json:"entries"` + } + unmarshalAuditResult(t, resp, &result) + if len(result.Entries) != 1 || result.Entries[0].ID != "audit-1" { + t.Fatalf("entries = %#v, want audit-1", result.Entries) + } + }, + }, + { + name: "audit.query forwards filters and returns entries", + method: "audit.query", + params: `{"limit":5,"since":"2026-06-01T00:00:00Z","until":"2026-06-25T00:00:00Z","command":"email list","status":"success","grant_id":"grant-1","request_id":"req-1","invoker":"ada","invoker_source":"terminal"}`, + svc: &fakeAuditService{ + query: func(ctx context.Context, opts *domain.AuditQueryOptions) ([]domain.AuditEntry, error) { + if opts.Limit != 5 || opts.Command != "email list" || opts.Status != "success" || opts.GrantID != "grant-1" || opts.RequestID != "req-1" || opts.Invoker != "ada" || opts.InvokerSource != "terminal" { + t.Fatalf("opts = %#v, want decoded query filters", opts) + } + if opts.Since.IsZero() || opts.Until.IsZero() { + t.Fatalf("opts times = %s %s, want decoded since/until", opts.Since, opts.Until) + } + return []domain.AuditEntry{entry}, nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + Entries []domain.AuditEntry `json:"entries"` + } + unmarshalAuditResult(t, resp, &result) + if len(result.Entries) != 1 || result.Entries[0].Command != "email list" { + t.Fatalf("entries = %#v, want email list", result.Entries) + } + }, + }, + { + name: "audit.summary returns summary", + method: "audit.summary", + params: `{"days":14}`, + svc: &fakeAuditService{ + summary: func(ctx context.Context, days int) (*domain.AuditSummary, error) { + if days != 14 { + t.Fatalf("days = %d, want 14", days) + } + return &domain.AuditSummary{ + StartDate: now.AddDate(0, 0, -14), + EndDate: now, + Days: 14, + TotalCommands: 3, + SuccessCount: 2, + ErrorCount: 1, + SuccessPercent: 66.67, + CommandCounts: map[string]int{"email list": 3}, + AccountCounts: map[string]int{"grant-1": 3}, + InvokerCounts: map[string]int{"ada": 3}, + TotalAPICalls: 4, + AvgResponseTime: time.Second, + APIErrorRate: 25, + }, nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result domain.AuditSummary + unmarshalAuditResult(t, resp, &result) + if result.Days != 14 || result.TotalCommands != 3 { + t.Fatalf("summary = %#v, want 14 days and 3 commands", result) + } + }, + }, + { + name: "audit.stats returns file stats", + method: "audit.stats", + params: `{}`, + svc: &fakeAuditService{ + stats: func() (int, int64, *domain.AuditEntry, error) { + return 3, 4096, &entry, nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + FileCount int `json:"file_count"` + TotalSizeBytes int64 `json:"total_size_bytes"` + OldestEntry *domain.AuditEntry `json:"oldest_entry"` + } + unmarshalAuditResult(t, resp, &result) + if result.FileCount != 3 || result.TotalSizeBytes != 4096 || result.OldestEntry == nil || result.OldestEntry.ID != "audit-1" { + t.Fatalf("stats = %#v, want file_count 3, total_size_bytes 4096, oldest audit-1", result) + } + }, + }, + { + name: "audit.config.read returns config", + method: "audit.config.read", + params: `{}`, + svc: &fakeAuditService{ + getConfig: func() (*domain.AuditConfig, error) { + return &domain.AuditConfig{ + Enabled: true, + Initialized: true, + Path: "/tmp/audit", + RetentionDays: 30, + MaxSizeMB: 100, + Format: "jsonl", + LogAPIDetails: true, + LogRequestID: true, + RotateDaily: true, + CompressOld: false, + }, nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result domain.AuditConfig + unmarshalAuditResult(t, resp, &result) + if !result.Enabled || result.Path != "/tmp/audit" || result.RetentionDays != 30 { + t.Fatalf("config = %#v, want enabled /tmp/audit retention 30", result) + } + }, + }, + { + name: "audit.config.save saves config", + method: "audit.config.save", + params: `{"enabled":true,"initialized":true,"path":"/tmp/audit","retention_days":45,"max_size_mb":200,"format":"jsonl","log_api_details":true,"log_request_id":true,"rotate_daily":true,"compress_old":true}`, + svc: &fakeAuditService{ + saveConfig: func(cfg *domain.AuditConfig) error { + if cfg == nil || !cfg.Enabled || cfg.Path != "/tmp/audit" || cfg.RetentionDays != 45 || cfg.MaxSizeMB != 200 || !cfg.CompressOld { + t.Fatalf("cfg = %#v, want decoded audit config", cfg) + } + return nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + OK bool `json:"ok"` + } + unmarshalAuditResult(t, resp, &result) + if !result.OK { + t.Fatal("ok = false, want true") + } + }, + }, + { + name: "audit.path returns path", + method: "audit.path", + params: `{}`, + svc: &fakeAuditService{ + path: func() string { + return "/tmp/audit" + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + Path string `json:"path"` + } + unmarshalAuditResult(t, resp, &result) + if result.Path != "/tmp/audit" { + t.Fatalf("path = %q, want /tmp/audit", result.Path) + } + }, + }, + { + name: "audit.clear clears logs", + method: "audit.clear", + params: `{}`, + svc: &fakeAuditService{ + clear: func(ctx context.Context) error { + return nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + Cleared bool `json:"cleared"` + } + unmarshalAuditResult(t, resp, &result) + if !result.Cleared { + t.Fatal("cleared = false, want true") + } + }, + }, + { + name: "audit.cleanup returns ok", + method: "audit.cleanup", + params: `{}`, + svc: &fakeAuditService{ + cleanup: func(ctx context.Context) error { + return nil + }, + }, + assert: func(t *testing.T, resp auditRPCResponse) { + requireNoAuditRPCError(t, resp) + + var result struct { + OK bool `json:"ok"` + } + unmarshalAuditResult(t, resp, &result) + if !result.OK { + t.Fatal("ok = false, want true") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuditHandlers(d, tt.svc) + + resp := dispatchAuditRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func TestRegisterAuditHandlers_ClientErrorMapsToInternalError(t *testing.T) { + clientErr := errors.New("audit unavailable") + d := NewDispatcher() + var loggedErr error + d.LogError = func(err error) { + loggedErr = err + } + RegisterAuditHandlers(d, &fakeAuditService{ + list: func(ctx context.Context, limit int) ([]domain.AuditEntry, error) { + return nil, clientErr + }, + }) + + resp := dispatchAuditRequest(t, d, "audit.list", `{}`) + requireAuditRPCErrorCode(t, resp, InternalError) + if !errors.Is(loggedErr, clientErr) { + t.Fatalf("logged error = %v, want wrapped %v", loggedErr, clientErr) + } +} + +func dispatchAuditRequest(t *testing.T, d *Dispatcher, method, params string) auditRPCResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp auditRPCResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} + +func requireNoAuditRPCError(t *testing.T, resp auditRPCResponse) { + t.Helper() + + if resp.Error != nil { + t.Fatalf("Error = %#v, want nil", resp.Error) + } +} + +func requireAuditRPCErrorCode(t *testing.T, resp auditRPCResponse, want int) { + t.Helper() + + if resp.Error == nil { + t.Fatal("Error = nil, want RPC error") + } + if resp.Error.Code != want { + t.Fatalf("Error.Code = %d, want %d", resp.Error.Code, want) + } +} + +func unmarshalAuditResult(t *testing.T, resp auditRPCResponse, dest any) { + t.Helper() + + if err := json.Unmarshal(resp.Result, dest); err != nil { + t.Fatalf("unmarshal result: %v", err) + } +} diff --git a/internal/adapters/rpcserver/handlers_auth.go b/internal/adapters/rpcserver/handlers_auth.go new file mode 100644 index 0000000..ba99c25 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_auth.go @@ -0,0 +1,128 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type authGrantGetParams struct { + GrantID string `json:"grant_id,omitempty"` +} + +type authGrantCreateCustomParams struct { + Provider string `json:"provider"` + Settings map[string]any `json:"settings,omitempty"` +} + +type authURLParams struct { + Provider string `json:"provider"` + RedirectURI string `json:"redirect_uri"` + State string `json:"state,omitempty"` + CodeChallenge string `json:"code_challenge,omitempty"` +} + +type authGrantRevokeResult struct { + Revoked bool `json:"revoked"` +} + +type authURLResult struct { + URL string `json:"url"` +} + +type authGrantExchangeParams struct { + Code string `json:"code"` + RedirectURI string `json:"redirect_uri"` + CodeVerifier string `json:"code_verifier,omitempty"` +} + +func RegisterAuthHandlers(d *Dispatcher, client ports.AuthClient, defaultGrant string) { + d.Register("auth.grant.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p authGrantGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + grant, err := client.GetGrant(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("auth.grant.get: %w", err) + } + return grant, nil + }) + + d.Register("auth.grant.revoke", func(ctx context.Context, params json.RawMessage) (any, error) { + var p authGrantGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.RevokeGrant(ctx, grantID); err != nil { + return nil, fmt.Errorf("auth.grant.revoke: %w", err) + } + return authGrantRevokeResult{Revoked: true}, nil + }) + + d.Register("auth.grant.createCustom", func(ctx context.Context, params json.RawMessage) (any, error) { + var p authGrantCreateCustomParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.Provider == "" { + return nil, NewRPCError(InvalidParams, "provider required", nil) + } + + grant, err := client.CreateCustomGrant(ctx, p.Provider, p.Settings) + if err != nil { + return nil, fmt.Errorf("auth.grant.createCustom: %w", err) + } + return grant, nil + }) + + d.Register("auth.url", func(_ context.Context, params json.RawMessage) (any, error) { + var p authURLParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.Provider == "" { + return nil, NewRPCError(InvalidParams, "provider required", nil) + } + if p.RedirectURI == "" { + return nil, NewRPCError(InvalidParams, "redirect_uri required", nil) + } + + url := client.BuildAuthURL(domain.Provider(p.Provider), p.RedirectURI, p.State, p.CodeChallenge) + return authURLResult{URL: url}, nil + }) + + d.Register("auth.grant.exchange", func(ctx context.Context, params json.RawMessage) (any, error) { + var p authGrantExchangeParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.Code == "" { + return nil, NewRPCError(InvalidParams, "code required", nil) + } + if p.RedirectURI == "" { + return nil, NewRPCError(InvalidParams, "redirect_uri required", nil) + } + + grant, err := client.ExchangeCode(ctx, p.Code, p.RedirectURI, p.CodeVerifier) + if err != nil { + return nil, fmt.Errorf("auth.grant.exchange: %w", err) + } + return grant, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_auth_test.go b/internal/adapters/rpcserver/handlers_auth_test.go new file mode 100644 index 0000000..5f59af5 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_auth_test.go @@ -0,0 +1,474 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeAuthClient struct { + ports.AuthClient + + buildAuthURL func(domain.Provider, string, string, string) string + getGrant func(context.Context, string) (*domain.Grant, error) + revokeGrant func(context.Context, string) error + createCustomGrant func(context.Context, string, map[string]any) (*domain.Grant, error) + exchangeCode func(context.Context, string, string, string) (*domain.Grant, error) + + buildAuthURLCalls int + getGrantCalls int + revokeGrantCalls int + createCustomGrantCalls int + exchangeCodeCalls int +} + +func (f *fakeAuthClient) ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { + f.exchangeCodeCalls++ + if f.exchangeCode == nil { + return nil, errors.New("unexpected ExchangeCode") + } + return f.exchangeCode(ctx, code, redirectURI, codeVerifier) +} + +func (f *fakeAuthClient) BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string { + f.buildAuthURLCalls++ + if f.buildAuthURL == nil { + return "" + } + return f.buildAuthURL(provider, redirectURI, state, codeChallenge) +} + +func (f *fakeAuthClient) GetGrant(ctx context.Context, grantID string) (*domain.Grant, error) { + f.getGrantCalls++ + if f.getGrant == nil { + return nil, errors.New("unexpected GetGrant") + } + return f.getGrant(ctx, grantID) +} + +func (f *fakeAuthClient) RevokeGrant(ctx context.Context, grantID string) error { + f.revokeGrantCalls++ + if f.revokeGrant == nil { + return errors.New("unexpected RevokeGrant") + } + return f.revokeGrant(ctx, grantID) +} + +func (f *fakeAuthClient) CreateCustomGrant(ctx context.Context, provider string, settings map[string]any) (*domain.Grant, error) { + f.createCustomGrantCalls++ + if f.createCustomGrant == nil { + return nil, errors.New("unexpected CreateCustomGrant") + } + return f.createCustomGrant(ctx, provider, settings) +} + +func TestRegisterAuthHandlers_GrantGet(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + params string + defaultGrant string + client *fakeAuthClient + assert func(*testing.T, *fakeAuthClient, rpcTestResponse) + }{ + { + name: "success", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeAuthClient{ + getGrant: func(ctx context.Context, grantID string) (*domain.Grant, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + return &domain.Grant{ID: "default-grant", Provider: domain.ProviderGoogle, Email: "user@example.com"}, nil + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.getGrantCalls != 1 { + t.Fatalf("GetGrant calls = %d, want 1", client.getGrantCalls) + } + + var grant domain.Grant + unmarshalResult(t, resp, &grant) + if grant.ID != "default-grant" || grant.Provider != domain.ProviderGoogle || grant.Email != "user@example.com" { + t.Fatalf("grant = %#v, want default-grant google user@example.com", grant) + } + }, + }, + { + name: "missing grant_id when no default", + params: `{}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.getGrantCalls != 0 { + t.Fatalf("GetGrant calls = %d, want 0", client.getGrantCalls) + } + }, + }, + { + name: "client error", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeAuthClient{ + getGrant: func(ctx context.Context, grantID string) (*domain.Grant, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", grantID) + } + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.getGrantCalls != 1 { + t.Fatalf("GetGrant calls = %d, want 1", client.getGrantCalls) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuthHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchAuthRequest(t, d, "auth.grant.get", tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func TestRegisterAuthHandlers_GrantRevoke(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + params string + defaultGrant string + client *fakeAuthClient + assert func(*testing.T, *fakeAuthClient, rpcTestResponse) + }{ + { + name: "success", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeAuthClient{ + revokeGrant: func(ctx context.Context, grantID string) error { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + return nil + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.revokeGrantCalls != 1 { + t.Fatalf("RevokeGrant calls = %d, want 1", client.revokeGrantCalls) + } + + var result authGrantRevokeResult + unmarshalResult(t, resp, &result) + if !result.Revoked { + t.Fatal("revoked = false, want true") + } + }, + }, + { + name: "missing grant", + params: `{}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.revokeGrantCalls != 0 { + t.Fatalf("RevokeGrant calls = %d, want 0", client.revokeGrantCalls) + } + }, + }, + { + name: "client error", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeAuthClient{ + revokeGrant: func(ctx context.Context, grantID string) error { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", grantID) + } + return clientErr + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.revokeGrantCalls != 1 { + t.Fatalf("RevokeGrant calls = %d, want 1", client.revokeGrantCalls) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuthHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchAuthRequest(t, d, "auth.grant.revoke", tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func TestRegisterAuthHandlers_GrantCreateCustom(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + params string + client *fakeAuthClient + assert func(*testing.T, *fakeAuthClient, rpcTestResponse) + }{ + { + name: "success", + params: `{"provider":"imap","settings":{"username":"user@example.com","host":"imap.example.com"}}`, + client: &fakeAuthClient{ + createCustomGrant: func(ctx context.Context, provider string, settings map[string]any) (*domain.Grant, error) { + if provider != "imap" { + t.Fatalf("provider = %q, want imap", provider) + } + if settings["username"] != "user@example.com" || settings["host"] != "imap.example.com" { + t.Fatalf("settings = %#v, want username and host", settings) + } + return &domain.Grant{ID: "grant-1", Provider: domain.ProviderIMAP, Email: "user@example.com"}, nil + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.createCustomGrantCalls != 1 { + t.Fatalf("CreateCustomGrant calls = %d, want 1", client.createCustomGrantCalls) + } + + var grant domain.Grant + unmarshalResult(t, resp, &grant) + if grant.ID != "grant-1" || grant.Provider != domain.ProviderIMAP { + t.Fatalf("grant = %#v, want grant-1 imap", grant) + } + }, + }, + { + name: "missing provider", + params: `{"settings":{"username":"user@example.com"}}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.createCustomGrantCalls != 0 { + t.Fatalf("CreateCustomGrant calls = %d, want 0", client.createCustomGrantCalls) + } + }, + }, + { + name: "client error", + params: `{"provider":"imap"}`, + client: &fakeAuthClient{ + createCustomGrant: func(ctx context.Context, provider string, settings map[string]any) (*domain.Grant, error) { + if provider != "imap" { + t.Fatalf("provider = %q, want imap", provider) + } + if settings != nil { + t.Fatalf("settings = %#v, want nil", settings) + } + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.createCustomGrantCalls != 1 { + t.Fatalf("CreateCustomGrant calls = %d, want 1", client.createCustomGrantCalls) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuthHandlers(d, tt.client, "") + + resp := dispatchAuthRequest(t, d, "auth.grant.createCustom", tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func TestRegisterAuthHandlers_URL(t *testing.T) { + tests := []struct { + name string + params string + client *fakeAuthClient + assert func(*testing.T, *fakeAuthClient, rpcTestResponse) + }{ + { + name: "success", + params: `{"provider":"google","redirect_uri":"http://localhost:8080/callback","state":"state-1","code_challenge":"challenge-1"}`, + client: &fakeAuthClient{ + buildAuthURL: func(provider domain.Provider, redirectURI, state, codeChallenge string) string { + if provider != domain.ProviderGoogle { + t.Fatalf("provider = %q, want google", provider) + } + if redirectURI != "http://localhost:8080/callback" { + t.Fatalf("redirectURI = %q, want callback URI", redirectURI) + } + if state != "state-1" || codeChallenge != "challenge-1" { + t.Fatalf("state/codeChallenge = %q/%q, want state-1/challenge-1", state, codeChallenge) + } + return "https://api.example.test/oauth?provider=google" + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.buildAuthURLCalls != 1 { + t.Fatalf("BuildAuthURL calls = %d, want 1", client.buildAuthURLCalls) + } + if client.getGrantCalls != 0 || client.revokeGrantCalls != 0 || client.createCustomGrantCalls != 0 { + t.Fatalf("API calls = get:%d revoke:%d create:%d, want none", client.getGrantCalls, client.revokeGrantCalls, client.createCustomGrantCalls) + } + + var result authURLResult + unmarshalResult(t, resp, &result) + if result.URL == "" { + t.Fatal("url = empty, want non-empty") + } + }, + }, + { + name: "missing provider", + params: `{"redirect_uri":"http://localhost:8080/callback"}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.buildAuthURLCalls != 0 { + t.Fatalf("BuildAuthURL calls = %d, want 0", client.buildAuthURLCalls) + } + }, + }, + { + name: "missing redirect_uri", + params: `{"provider":"google"}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.buildAuthURLCalls != 0 { + t.Fatalf("BuildAuthURL calls = %d, want 0", client.buildAuthURLCalls) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuthHandlers(d, tt.client, "") + + resp := dispatchAuthRequest(t, d, "auth.url", tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func TestRegisterAuthHandlers_GrantExchange(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + params string + client *fakeAuthClient + assert func(*testing.T, *fakeAuthClient, rpcTestResponse) + }{ + { + name: "success", + params: `{"code":"auth-code-1","redirect_uri":"http://localhost:8080/callback","code_verifier":"verifier-1"}`, + client: &fakeAuthClient{ + exchangeCode: func(_ context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { + if code != "auth-code-1" || redirectURI != "http://localhost:8080/callback" || codeVerifier != "verifier-1" { + t.Fatalf("args = %q/%q/%q, want auth-code-1/callback/verifier-1", code, redirectURI, codeVerifier) + } + return &domain.Grant{ID: "grant-123"}, nil + }, + }, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.exchangeCodeCalls != 1 { + t.Fatalf("ExchangeCode calls = %d, want 1", client.exchangeCodeCalls) + } + var grant domain.Grant + unmarshalResult(t, resp, &grant) + if grant.ID != "grant-123" { + t.Fatalf("grant ID = %q, want grant-123", grant.ID) + } + }, + }, + { + name: "missing code", + params: `{"redirect_uri":"http://localhost:8080/callback"}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.exchangeCodeCalls != 0 { + t.Fatalf("ExchangeCode calls = %d, want 0", client.exchangeCodeCalls) + } + }, + }, + { + name: "missing redirect_uri", + params: `{"code":"auth-code-1"}`, + client: &fakeAuthClient{}, + assert: func(t *testing.T, client *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error", + params: `{"code":"auth-code-1","redirect_uri":"http://localhost:8080/callback"}`, + client: &fakeAuthClient{ + exchangeCode: func(context.Context, string, string, string) (*domain.Grant, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, _ *fakeAuthClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAuthHandlers(d, tt.client, "") + + resp := dispatchAuthRequest(t, d, "auth.grant.exchange", tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func dispatchAuthRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} diff --git a/internal/adapters/rpcserver/handlers_calendar.go b/internal/adapters/rpcserver/handlers_calendar.go new file mode 100644 index 0000000..d9e3615 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar.go @@ -0,0 +1,117 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type calendarListParams struct { + GrantID string `json:"grant_id,omitempty"` +} + +type calendarListResult struct { + Calendars []domain.Calendar `json:"calendars"` +} + +type eventListParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + Limit int `json:"limit,omitempty"` + PageToken string `json:"page_token,omitempty"` + UpdatedAfter int64 `json:"updated_after,omitempty"` + Start int64 `json:"start,omitempty"` + End int64 `json:"end,omitempty"` +} + +type eventListResult struct { + Events []domain.Event `json:"events"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` +} + +type eventGetParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` +} + +func RegisterCalendarHandlers(d *Dispatcher, client ports.CalendarClient, defaultGrant string) { + d.Register("calendar.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + calendars, err := client.GetCalendars(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("calendar.list: %w", err) + } + return calendarListResult{Calendars: calendars}, nil + }) + + d.Register("event.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + calendarID := p.CalendarID + if calendarID == "" { + calendarID = "primary" + } + + resp, err := client.GetEventsWithCursor(ctx, grantID, calendarID, &domain.EventQueryParams{ + Limit: p.Limit, + PageToken: p.PageToken, + UpdatedAfter: p.UpdatedAfter, + Start: p.Start, + End: p.End, + }) + if err != nil { + return nil, fmt.Errorf("event.list: %w", err) + } + return eventListResult{ + Events: resp.Data, + NextCursor: resp.Pagination.NextCursor, + HasMore: resp.Pagination.HasMore, + }, nil + }) + + d.Register("event.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + calendarID := p.CalendarID + if calendarID == "" { + calendarID = "primary" + } + + event, err := client.GetEvent(ctx, grantID, calendarID, p.EventID) + if err != nil { + return nil, fmt.Errorf("event.get: %w", err) + } + return event, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_calendar_ext.go b/internal/adapters/rpcserver/handlers_calendar_ext.go new file mode 100644 index 0000000..ebe68ce --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar_ext.go @@ -0,0 +1,360 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type calendarGetParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id"` +} + +type calendarCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateCalendarRequest +} + +type calendarUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id"` + domain.UpdateCalendarRequest +} + +type calendarDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id"` +} + +type calendarFreeBusyParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.FreeBusyRequest +} + +type calendarAvailabilityParams struct { + domain.AvailabilityRequest +} + +type calendarResourcesParams struct { + GrantID string `json:"grant_id,omitempty"` +} + +type calendarResourcesResult struct { + Resources []domain.RoomResource `json:"resources"` +} + +type eventImportParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.EventQueryParams +} + +type eventListPayloadResult struct { + Events []domain.Event `json:"events"` +} + +type eventRecurringListParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + MasterEventID string `json:"master_event_id"` + domain.EventQueryParams +} + +type eventRecurringUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` + domain.UpdateEventRequest +} + +type eventRecurringDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` +} + +type virtualCalendarCreateParams struct { + Email string `json:"email"` +} + +type virtualCalendarListResult struct { + Grants []domain.VirtualCalendarGrant `json:"grants"` +} + +type virtualCalendarIDParams struct { + GrantID string `json:"grant_id"` +} + +// RegisterCalendarExtHandlers registers calendar CRUD, availability/free-busy, +// recurring-instance, room resource, import, and virtual-calendar-grant methods. +func RegisterCalendarExtHandlers(d *Dispatcher, client ports.CalendarClient, defaultGrant string) { + d.Register("calendar.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CalendarID == "" { + return nil, NewRPCError(InvalidParams, "calendar_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + calendar, err := client.GetCalendar(ctx, grantID, p.CalendarID) + if err != nil { + return nil, fmt.Errorf("calendar.get: %w", err) + } + return calendar, nil + }) + + d.Register("calendar.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + calendar, err := client.CreateCalendar(ctx, grantID, &p.CreateCalendarRequest) + if err != nil { + return nil, fmt.Errorf("calendar.create: %w", err) + } + return calendar, nil + }) + + d.Register("calendar.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CalendarID == "" { + return nil, NewRPCError(InvalidParams, "calendar_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + calendar, err := client.UpdateCalendar(ctx, grantID, p.CalendarID, &p.UpdateCalendarRequest) + if err != nil { + return nil, fmt.Errorf("calendar.update: %w", err) + } + return calendar, nil + }) + + d.Register("calendar.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CalendarID == "" { + return nil, NewRPCError(InvalidParams, "calendar_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteCalendar(ctx, grantID, p.CalendarID); err != nil { + return nil, fmt.Errorf("calendar.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("calendar.freeBusy", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarFreeBusyParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + result, err := client.GetFreeBusy(ctx, grantID, &p.FreeBusyRequest) + if err != nil { + return nil, fmt.Errorf("calendar.freeBusy: %w", err) + } + return result, nil + }) + + d.Register("calendar.availability", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarAvailabilityParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + result, err := client.GetAvailability(ctx, &p.AvailabilityRequest) + if err != nil { + return nil, fmt.Errorf("calendar.availability: %w", err) + } + return result, nil + }) + + d.Register("calendar.resources", func(ctx context.Context, params json.RawMessage) (any, error) { + var p calendarResourcesParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + resources, err := client.ListRoomResources(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("calendar.resources: %w", err) + } + return calendarResourcesResult{Resources: resources}, nil + }) + + d.Register("event.import", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventImportParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.CalendarID == "" { + return nil, NewRPCError(InvalidParams, "calendar_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + events, err := client.ImportEvents(ctx, grantID, &p.EventQueryParams) + if err != nil { + return nil, fmt.Errorf("event.import: %w", err) + } + return eventListPayloadResult{Events: events}, nil + }) + + d.Register("event.recurring.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventRecurringListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MasterEventID == "" { + return nil, NewRPCError(InvalidParams, "master_event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + events, err := client.GetRecurringEventInstances(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.MasterEventID, &p.EventQueryParams) + if err != nil { + return nil, fmt.Errorf("event.recurring.list: %w", err) + } + return eventListPayloadResult{Events: events}, nil + }) + + d.Register("event.recurring.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventRecurringUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + event, err := client.UpdateRecurringEventInstance(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.EventID, &p.UpdateEventRequest) + if err != nil { + return nil, fmt.Errorf("event.recurring.update: %w", err) + } + return event, nil + }) + + d.Register("event.recurring.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventRecurringDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteRecurringEventInstance(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.EventID); err != nil { + return nil, fmt.Errorf("event.recurring.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("calendar.virtual.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p virtualCalendarCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.Email == "" { + return nil, NewRPCError(InvalidParams, "email required", nil) + } + + grant, err := client.CreateVirtualCalendarGrant(ctx, p.Email) + if err != nil { + return nil, fmt.Errorf("calendar.virtual.create: %w", err) + } + return grant, nil + }) + + d.Register("calendar.virtual.list", func(ctx context.Context, _ json.RawMessage) (any, error) { + grants, err := client.ListVirtualCalendarGrants(ctx) + if err != nil { + return nil, fmt.Errorf("calendar.virtual.list: %w", err) + } + return virtualCalendarListResult{Grants: grants}, nil + }) + + d.Register("calendar.virtual.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p virtualCalendarIDParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GrantID == "" { + return nil, NewRPCError(InvalidParams, "grant_id required", nil) + } + + grant, err := client.GetVirtualCalendarGrant(ctx, p.GrantID) + if err != nil { + return nil, fmt.Errorf("calendar.virtual.get: %w", err) + } + return grant, nil + }) + + d.Register("calendar.virtual.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p virtualCalendarIDParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GrantID == "" { + return nil, NewRPCError(InvalidParams, "grant_id required", nil) + } + + if err := client.DeleteVirtualCalendarGrant(ctx, p.GrantID); err != nil { + return nil, fmt.Errorf("calendar.virtual.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_calendar_ext_test.go b/internal/adapters/rpcserver/handlers_calendar_ext_test.go new file mode 100644 index 0000000..c72fe87 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar_ext_test.go @@ -0,0 +1,476 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeCalendarExtClient struct { + ports.CalendarClient + + getCalendar func(context.Context, string, string) (*domain.Calendar, error) + createCalendar func(context.Context, string, *domain.CreateCalendarRequest) (*domain.Calendar, error) + updateCalendar func(context.Context, string, string, *domain.UpdateCalendarRequest) (*domain.Calendar, error) + deleteCalendar func(context.Context, string, string) error + getFreeBusy func(context.Context, string, *domain.FreeBusyRequest) (*domain.FreeBusyResponse, error) + getAvailability func(context.Context, *domain.AvailabilityRequest) (*domain.AvailabilityResponse, error) + listRoomResources func(context.Context, string) ([]domain.RoomResource, error) + importEvents func(context.Context, string, *domain.EventQueryParams) ([]domain.Event, error) + getRecurring func(context.Context, string, string, string, *domain.EventQueryParams) ([]domain.Event, error) + updateRecurring func(context.Context, string, string, string, *domain.UpdateEventRequest) (*domain.Event, error) + deleteRecurring func(context.Context, string, string, string) error + createVirtual func(context.Context, string) (*domain.VirtualCalendarGrant, error) + listVirtual func(context.Context) ([]domain.VirtualCalendarGrant, error) + getVirtual func(context.Context, string) (*domain.VirtualCalendarGrant, error) + deleteVirtual func(context.Context, string) error +} + +func (f *fakeCalendarExtClient) GetCalendar(ctx context.Context, grantID, calendarID string) (*domain.Calendar, error) { + if f.getCalendar == nil { + return nil, errors.New("unexpected GetCalendar") + } + return f.getCalendar(ctx, grantID, calendarID) +} + +func (f *fakeCalendarExtClient) CreateCalendar(ctx context.Context, grantID string, req *domain.CreateCalendarRequest) (*domain.Calendar, error) { + if f.createCalendar == nil { + return nil, errors.New("unexpected CreateCalendar") + } + return f.createCalendar(ctx, grantID, req) +} + +func (f *fakeCalendarExtClient) UpdateCalendar(ctx context.Context, grantID, calendarID string, req *domain.UpdateCalendarRequest) (*domain.Calendar, error) { + if f.updateCalendar == nil { + return nil, errors.New("unexpected UpdateCalendar") + } + return f.updateCalendar(ctx, grantID, calendarID, req) +} + +func (f *fakeCalendarExtClient) DeleteCalendar(ctx context.Context, grantID, calendarID string) error { + if f.deleteCalendar == nil { + return errors.New("unexpected DeleteCalendar") + } + return f.deleteCalendar(ctx, grantID, calendarID) +} + +func (f *fakeCalendarExtClient) GetFreeBusy(ctx context.Context, grantID string, req *domain.FreeBusyRequest) (*domain.FreeBusyResponse, error) { + if f.getFreeBusy == nil { + return nil, errors.New("unexpected GetFreeBusy") + } + return f.getFreeBusy(ctx, grantID, req) +} + +func (f *fakeCalendarExtClient) GetAvailability(ctx context.Context, req *domain.AvailabilityRequest) (*domain.AvailabilityResponse, error) { + if f.getAvailability == nil { + return nil, errors.New("unexpected GetAvailability") + } + return f.getAvailability(ctx, req) +} + +func (f *fakeCalendarExtClient) ListRoomResources(ctx context.Context, grantID string) ([]domain.RoomResource, error) { + if f.listRoomResources == nil { + return nil, errors.New("unexpected ListRoomResources") + } + return f.listRoomResources(ctx, grantID) +} + +func (f *fakeCalendarExtClient) ImportEvents(ctx context.Context, grantID string, params *domain.EventQueryParams) ([]domain.Event, error) { + if f.importEvents == nil { + return nil, errors.New("unexpected ImportEvents") + } + return f.importEvents(ctx, grantID, params) +} + +func (f *fakeCalendarExtClient) GetRecurringEventInstances(ctx context.Context, grantID, calendarID, masterEventID string, params *domain.EventQueryParams) ([]domain.Event, error) { + if f.getRecurring == nil { + return nil, errors.New("unexpected GetRecurringEventInstances") + } + return f.getRecurring(ctx, grantID, calendarID, masterEventID, params) +} + +func (f *fakeCalendarExtClient) UpdateRecurringEventInstance(ctx context.Context, grantID, calendarID, eventID string, req *domain.UpdateEventRequest) (*domain.Event, error) { + if f.updateRecurring == nil { + return nil, errors.New("unexpected UpdateRecurringEventInstance") + } + return f.updateRecurring(ctx, grantID, calendarID, eventID, req) +} + +func (f *fakeCalendarExtClient) DeleteRecurringEventInstance(ctx context.Context, grantID, calendarID, eventID string) error { + if f.deleteRecurring == nil { + return errors.New("unexpected DeleteRecurringEventInstance") + } + return f.deleteRecurring(ctx, grantID, calendarID, eventID) +} + +func (f *fakeCalendarExtClient) CreateVirtualCalendarGrant(ctx context.Context, email string) (*domain.VirtualCalendarGrant, error) { + if f.createVirtual == nil { + return nil, errors.New("unexpected CreateVirtualCalendarGrant") + } + return f.createVirtual(ctx, email) +} + +func (f *fakeCalendarExtClient) ListVirtualCalendarGrants(ctx context.Context) ([]domain.VirtualCalendarGrant, error) { + if f.listVirtual == nil { + return nil, errors.New("unexpected ListVirtualCalendarGrants") + } + return f.listVirtual(ctx) +} + +func (f *fakeCalendarExtClient) GetVirtualCalendarGrant(ctx context.Context, grantID string) (*domain.VirtualCalendarGrant, error) { + if f.getVirtual == nil { + return nil, errors.New("unexpected GetVirtualCalendarGrant") + } + return f.getVirtual(ctx, grantID) +} + +func (f *fakeCalendarExtClient) DeleteVirtualCalendarGrant(ctx context.Context, grantID string) error { + if f.deleteVirtual == nil { + return errors.New("unexpected DeleteVirtualCalendarGrant") + } + return f.deleteVirtual(ctx, grantID) +} + +func TestRegisterCalendarExtHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeCalendarExtClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "calendar.get returns calendar", + method: "calendar.get", + params: `{"calendar_id":"cal-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + getCalendar: func(_ context.Context, grantID, calendarID string) (*domain.Calendar, error) { + if grantID != "default-grant" || calendarID != "cal-1" { + t.Fatalf("args = %q/%q, want default-grant/cal-1", grantID, calendarID) + } + return &domain.Calendar{ID: "cal-1"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var cal domain.Calendar + unmarshalResult(t, resp, &cal) + if cal.ID != "cal-1" { + t.Fatalf("calendar ID = %q, want cal-1", cal.ID) + } + }, + }, + { + name: "calendar.get missing calendar_id", + method: "calendar.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "calendar.create returns calendar", + method: "calendar.create", + params: `{"name":"Team"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + createCalendar: func(_ context.Context, _ string, _ *domain.CreateCalendarRequest) (*domain.Calendar, error) { + return &domain.Calendar{ID: "cal-new"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var cal domain.Calendar + unmarshalResult(t, resp, &cal) + if cal.ID != "cal-new" { + t.Fatalf("calendar ID = %q, want cal-new", cal.ID) + } + }, + }, + { + name: "calendar.update missing calendar_id", + method: "calendar.update", + params: `{"name":"x"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "calendar.delete returns deleted", + method: "calendar.delete", + params: `{"calendar_id":"cal-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + deleteCalendar: func(_ context.Context, _, calendarID string) error { + if calendarID != "cal-1" { + t.Fatalf("calendarID = %q, want cal-1", calendarID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "calendar.freeBusy returns response", + method: "calendar.freeBusy", + params: `{"start_time":100,"end_time":200,"emails":["a@example.com"]}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + getFreeBusy: func(_ context.Context, grantID string, _ *domain.FreeBusyRequest) (*domain.FreeBusyResponse, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + return &domain.FreeBusyResponse{}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { requireNoRPCError(t, resp) }, + }, + { + name: "calendar.availability returns response without grant", + method: "calendar.availability", + params: `{"start_time":100,"end_time":200}`, + client: &fakeCalendarExtClient{ + getAvailability: func(context.Context, *domain.AvailabilityRequest) (*domain.AvailabilityResponse, error) { + return &domain.AvailabilityResponse{}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { requireNoRPCError(t, resp) }, + }, + { + name: "calendar.resources returns resources", + method: "calendar.resources", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + listRoomResources: func(context.Context, string) ([]domain.RoomResource, error) { + return []domain.RoomResource{{Email: "room@example.com", Name: "Big Room"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result calendarResourcesResult + unmarshalResult(t, resp, &result) + if len(result.Resources) != 1 || result.Resources[0].Name != "Big Room" { + t.Fatalf("resources = %+v, want one Big Room", result.Resources) + } + }, + }, + { + name: "event.import returns events", + method: "event.import", + params: `{"calendar_id":"cal-1","start":100,"end":200}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + importEvents: func(_ context.Context, _ string, params *domain.EventQueryParams) ([]domain.Event, error) { + if params.CalendarID != "cal-1" { + t.Fatalf("calendar_id = %q, want cal-1", params.CalendarID) + } + return []domain.Event{{ID: "ev-1"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result eventListPayloadResult + unmarshalResult(t, resp, &result) + if len(result.Events) != 1 || result.Events[0].ID != "ev-1" { + t.Fatalf("events = %+v, want one ev-1", result.Events) + } + }, + }, + { + name: "event.import missing calendar_id", + method: "event.import", + params: `{"start":100,"end":200}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "event.recurring.list missing master_event_id", + method: "event.recurring.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "event.recurring.list returns instances with default calendar", + method: "event.recurring.list", + params: `{"master_event_id":"master-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + getRecurring: func(_ context.Context, _, calendarID, masterEventID string, _ *domain.EventQueryParams) ([]domain.Event, error) { + if calendarID != "primary" || masterEventID != "master-1" { + t.Fatalf("args = %q/%q, want primary/master-1", calendarID, masterEventID) + } + return []domain.Event{{ID: "inst-1"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result eventListPayloadResult + unmarshalResult(t, resp, &result) + if len(result.Events) != 1 { + t.Fatalf("events = %+v, want one", result.Events) + } + }, + }, + { + name: "event.recurring.update missing event_id", + method: "event.recurring.update", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "event.recurring.delete returns deleted", + method: "event.recurring.delete", + params: `{"event_id":"ev-9"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + deleteRecurring: func(_ context.Context, _, calendarID, eventID string) error { + if calendarID != "primary" || eventID != "ev-9" { + t.Fatalf("args = %q/%q, want primary/ev-9", calendarID, eventID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "calendar.virtual.create requires email", + method: "calendar.virtual.create", + params: `{}`, + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "calendar.virtual.create returns grant", + method: "calendar.virtual.create", + params: `{"email":"vcal@example.com"}`, + client: &fakeCalendarExtClient{ + createVirtual: func(_ context.Context, email string) (*domain.VirtualCalendarGrant, error) { + if email != "vcal@example.com" { + t.Fatalf("email = %q, want vcal@example.com", email) + } + return &domain.VirtualCalendarGrant{ID: "vg-1", Email: email}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var grant domain.VirtualCalendarGrant + unmarshalResult(t, resp, &grant) + if grant.ID != "vg-1" { + t.Fatalf("grant ID = %q, want vg-1", grant.ID) + } + }, + }, + { + name: "calendar.virtual.list returns grants", + method: "calendar.virtual.list", + params: `{}`, + client: &fakeCalendarExtClient{ + listVirtual: func(context.Context) ([]domain.VirtualCalendarGrant, error) { + return []domain.VirtualCalendarGrant{{ID: "vg-1"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result virtualCalendarListResult + unmarshalResult(t, resp, &result) + if len(result.Grants) != 1 { + t.Fatalf("grants = %+v, want one", result.Grants) + } + }, + }, + { + name: "calendar.virtual.get missing grant_id", + method: "calendar.virtual.get", + params: `{}`, + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "calendar.virtual.delete returns deleted", + method: "calendar.virtual.delete", + params: `{"grant_id":"vg-1"}`, + client: &fakeCalendarExtClient{ + deleteVirtual: func(_ context.Context, grantID string) error { + if grantID != "vg-1" { + t.Fatalf("grantID = %q, want vg-1", grantID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "client error surfaces as internal error", + method: "calendar.get", + params: `{"calendar_id":"cal-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarExtClient{ + getCalendar: func(context.Context, string, string) (*domain.Calendar, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InternalError) }, + }, + { + name: "missing default grant errors", + method: "calendar.get", + params: `{"calendar_id":"cal-1"}`, + defaultGrant: "", + client: &fakeCalendarExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterCalendarExtHandlers(d, tt.client, tt.defaultGrant) + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + tt.method + `","params":` + tt.params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + tt.assert(t, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_calendar_test.go b/internal/adapters/rpcserver/handlers_calendar_test.go new file mode 100644 index 0000000..6f9d4de --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar_test.go @@ -0,0 +1,227 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeCalendarClient struct { + ports.CalendarClient + + getCalendars func(context.Context, string) ([]domain.Calendar, error) + getEventsWithCursor func(context.Context, string, string, *domain.EventQueryParams) (*domain.EventListResponse, error) + getEvent func(context.Context, string, string, string) (*domain.Event, error) +} + +func (f *fakeCalendarClient) GetCalendars(ctx context.Context, grantID string) ([]domain.Calendar, error) { + if f.getCalendars == nil { + return nil, errors.New("unexpected GetCalendars") + } + return f.getCalendars(ctx, grantID) +} + +func (f *fakeCalendarClient) GetEventsWithCursor(ctx context.Context, grantID, calendarID string, params *domain.EventQueryParams) (*domain.EventListResponse, error) { + if f.getEventsWithCursor == nil { + return nil, errors.New("unexpected GetEventsWithCursor") + } + return f.getEventsWithCursor(ctx, grantID, calendarID, params) +} + +func (f *fakeCalendarClient) GetEvent(ctx context.Context, grantID, calendarID, eventID string) (*domain.Event, error) { + if f.getEvent == nil { + return nil, errors.New("unexpected GetEvent") + } + return f.getEvent(ctx, grantID, calendarID, eventID) +} + +func TestRegisterCalendarHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeCalendarClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "calendar.list returns calendars", + method: "calendar.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{ + getCalendars: func(ctx context.Context, grantID string) ([]domain.Calendar, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + return []domain.Calendar{{ID: "cal-1", Name: "Work"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Calendars []domain.Calendar `json:"calendars"` + } + unmarshalResult(t, resp, &result) + if len(result.Calendars) != 1 || result.Calendars[0].ID != "cal-1" { + t.Fatalf("calendars = %#v, want cal-1", result.Calendars) + } + }, + }, + { + name: "event.list defaults calendar and forwards query params", + method: "event.list", + params: `{"grant_id":"request-grant","limit":25,"page_token":"cursor-1","updated_after":1710000000,"start":1710000100,"end":1710000200}`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{ + getEventsWithCursor: func(ctx context.Context, grantID, calendarID string, params *domain.EventQueryParams) (*domain.EventListResponse, error) { + if grantID != "request-grant" { + t.Fatalf("grantID = %q, want request-grant", grantID) + } + if calendarID != "primary" { + t.Fatalf("calendarID = %q, want primary", calendarID) + } + if params.Limit != 25 || params.PageToken != "cursor-1" || params.UpdatedAfter != 1710000000 || params.Start != 1710000100 || params.End != 1710000200 { + t.Fatalf("params = %+v, want forwarded query params", params) + } + return &domain.EventListResponse{ + Data: []domain.Event{{ID: "event-1", Title: "Sync"}}, + Pagination: domain.Pagination{NextCursor: "cursor-2", HasMore: true}, + }, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Events []domain.Event `json:"events"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` + } + unmarshalResult(t, resp, &result) + if len(result.Events) != 1 || result.Events[0].ID != "event-1" || result.NextCursor != "cursor-2" || !result.HasMore { + t.Fatalf("result = %+v, want event-1 cursor-2 has_more", result) + } + }, + }, + { + name: "event.list uses request calendar", + method: "event.list", + params: `{"calendar_id":"cal-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{ + getEventsWithCursor: func(ctx context.Context, grantID, calendarID string, params *domain.EventQueryParams) (*domain.EventListResponse, error) { + if calendarID != "cal-1" { + t.Fatalf("calendarID = %q, want cal-1", calendarID) + } + return &domain.EventListResponse{}, nil + }, + }, + assert: requireNoRPCError, + }, + { + name: "event.get returns event", + method: "event.get", + params: `{"grant_id":"grant-1","calendar_id":"cal-1","event_id":"event-1"}`, + client: &fakeCalendarClient{ + getEvent: func(ctx context.Context, grantID, calendarID, eventID string) (*domain.Event, error) { + if grantID != "grant-1" || calendarID != "cal-1" || eventID != "event-1" { + t.Fatalf("args = %q %q %q, want grant-1 cal-1 event-1", grantID, calendarID, eventID) + } + return &domain.Event{ID: "event-1", Title: "Sync"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var event domain.Event + unmarshalResult(t, resp, &event) + if event.ID != "event-1" || event.Title != "Sync" { + t.Fatalf("event = %#v, want event-1 Sync", event) + } + }, + }, + { + name: "calendar.list missing grant returns invalid params", + method: "calendar.list", + params: `{}`, + client: &fakeCalendarClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "event.get missing event_id returns invalid params", + method: "event.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "malformed params returns invalid params", + method: "event.list", + params: `"bad"`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error maps to internal error", + method: "event.get", + params: `{"event_id":"event-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarClient{ + getEvent: func(ctx context.Context, grantID, calendarID, eventID string) (*domain.Event, error) { + if calendarID != "primary" { + t.Fatalf("calendarID = %q, want primary", calendarID) + } + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterCalendarHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchCalendarRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func dispatchCalendarRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} diff --git a/internal/adapters/rpcserver/handlers_calendar_write.go b/internal/adapters/rpcserver/handlers_calendar_write.go new file mode 100644 index 0000000..3b084c1 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar_write.go @@ -0,0 +1,128 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type eventCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + domain.CreateEventRequest +} + +type eventUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` + domain.UpdateEventRequest +} + +type eventDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` +} + +type eventRSVPParams struct { + GrantID string `json:"grant_id,omitempty"` + CalendarID string `json:"calendar_id,omitempty"` + EventID string `json:"event_id"` + domain.SendRSVPRequest +} + +type eventRSVPResult struct { + OK bool `json:"ok"` +} + +func RegisterCalendarWriteHandlers(d *Dispatcher, client ports.CalendarClient, defaultGrant string) { + d.Register("event.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + event, err := client.CreateEvent(ctx, grantID, calendarIDOrPrimary(p.CalendarID), &p.CreateEventRequest) + if err != nil { + return nil, fmt.Errorf("event.create: %w", err) + } + return event, nil + }) + + d.Register("event.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + event, err := client.UpdateEvent(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.EventID, &p.UpdateEventRequest) + if err != nil { + return nil, fmt.Errorf("event.update: %w", err) + } + return event, nil + }) + + d.Register("event.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteEvent(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.EventID); err != nil { + return nil, fmt.Errorf("event.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("event.rsvp", func(ctx context.Context, params json.RawMessage) (any, error) { + var p eventRSVPParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.SendRSVP(ctx, grantID, calendarIDOrPrimary(p.CalendarID), p.EventID, &p.SendRSVPRequest); err != nil { + return nil, fmt.Errorf("event.rsvp: %w", err) + } + return eventRSVPResult{OK: true}, nil + }) +} + +func calendarIDOrPrimary(calendarID string) string { + if calendarID != "" { + return calendarID + } + return "primary" +} diff --git a/internal/adapters/rpcserver/handlers_calendar_write_test.go b/internal/adapters/rpcserver/handlers_calendar_write_test.go new file mode 100644 index 0000000..85e9604 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_calendar_write_test.go @@ -0,0 +1,283 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeCalendarWriteClient struct { + ports.CalendarClient + + createEvent func(context.Context, string, string, *domain.CreateEventRequest) (*domain.Event, error) + updateEvent func(context.Context, string, string, string, *domain.UpdateEventRequest) (*domain.Event, error) + deleteEvent func(context.Context, string, string, string) error + sendRSVP func(context.Context, string, string, string, *domain.SendRSVPRequest) error + + createGrantID string + createCalendarID string + createReq domain.CreateEventRequest + updateGrantID string + updateCalendarID string + updateEventID string + updateReq domain.UpdateEventRequest + deleteGrantID string + deleteCalendarID string + deleteEventID string + rsvpGrantID string + rsvpCalendarID string + rsvpEventID string + rsvpReq domain.SendRSVPRequest +} + +func (f *fakeCalendarWriteClient) CreateEvent(ctx context.Context, grantID, calendarID string, req *domain.CreateEventRequest) (*domain.Event, error) { + f.createGrantID = grantID + f.createCalendarID = calendarID + if req != nil { + f.createReq = *req + } + if f.createEvent == nil { + return nil, errors.New("unexpected CreateEvent") + } + return f.createEvent(ctx, grantID, calendarID, req) +} + +func (f *fakeCalendarWriteClient) UpdateEvent(ctx context.Context, grantID, calendarID, eventID string, req *domain.UpdateEventRequest) (*domain.Event, error) { + f.updateGrantID = grantID + f.updateCalendarID = calendarID + f.updateEventID = eventID + if req != nil { + f.updateReq = *req + } + if f.updateEvent == nil { + return nil, errors.New("unexpected UpdateEvent") + } + return f.updateEvent(ctx, grantID, calendarID, eventID, req) +} + +func (f *fakeCalendarWriteClient) DeleteEvent(ctx context.Context, grantID, calendarID, eventID string) error { + f.deleteGrantID = grantID + f.deleteCalendarID = calendarID + f.deleteEventID = eventID + if f.deleteEvent == nil { + return errors.New("unexpected DeleteEvent") + } + return f.deleteEvent(ctx, grantID, calendarID, eventID) +} + +func (f *fakeCalendarWriteClient) SendRSVP(ctx context.Context, grantID, calendarID, eventID string, req *domain.SendRSVPRequest) error { + f.rsvpGrantID = grantID + f.rsvpCalendarID = calendarID + f.rsvpEventID = eventID + if req != nil { + f.rsvpReq = *req + } + if f.sendRSVP == nil { + return errors.New("unexpected SendRSVP") + } + return f.sendRSVP(ctx, grantID, calendarID, eventID, req) +} + +func TestRegisterCalendarWriteHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + title := "Updated sync" + busy := true + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeCalendarWriteClient + assert func(*testing.T, *fakeCalendarWriteClient, rpcTestResponse) + }{ + { + name: "event.create defaults calendar and forwards embedded request", + method: "event.create", + params: `{"title":"Sync","description":"Weekly","location":"Room 1","when":{"start_time":1710000000,"end_time":1710003600},"participants":[{"email":"ada@example.com","name":"Ada"}],"busy":true,"visibility":"private","recurrence":["RRULE:FREQ=WEEKLY"],"metadata":{"source":"rpc"}}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{ + createEvent: func(ctx context.Context, grantID, calendarID string, req *domain.CreateEventRequest) (*domain.Event, error) { + return &domain.Event{ID: "event-1", Title: req.Title}, nil + }, + }, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + if client.createGrantID != "default-grant" || client.createCalendarID != "primary" { + t.Fatalf("create args = %q %q, want default-grant primary", client.createGrantID, client.createCalendarID) + } + if client.createReq.Title != "Sync" || client.createReq.Description != "Weekly" || client.createReq.Location != "Room 1" { + t.Fatalf("create request = %+v, want embedded string fields", client.createReq) + } + if client.createReq.When.StartTime != 1710000000 || client.createReq.When.EndTime != 1710003600 { + t.Fatalf("create when = %+v, want forwarded times", client.createReq.When) + } + if len(client.createReq.Participants) != 1 || client.createReq.Participants[0].Email != "ada@example.com" || client.createReq.Participants[0].Name != "Ada" { + t.Fatalf("participants = %+v, want Ada", client.createReq.Participants) + } + if !client.createReq.Busy || client.createReq.Visibility != "private" || len(client.createReq.Recurrence) != 1 || client.createReq.Metadata["source"] != "rpc" { + t.Fatalf("create request = %+v, want busy visibility recurrence metadata", client.createReq) + } + + var event domain.Event + unmarshalResult(t, resp, &event) + if event.ID != "event-1" || event.Title != "Sync" { + t.Fatalf("event = %+v, want event-1 Sync", event) + } + }, + }, + { + name: "event.create missing grant returns invalid params", + method: "event.create", + params: `{"title":"Sync"}`, + client: &fakeCalendarWriteClient{}, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.createGrantID != "" { + t.Fatalf("CreateEvent called with grant %q, want no call", client.createGrantID) + } + }, + }, + { + name: "event.update requires event_id", + method: "event.update", + params: `{"title":"Sync"}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{}, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.updateEventID != "" { + t.Fatalf("UpdateEvent called with event %q, want no call", client.updateEventID) + } + }, + }, + { + name: "event.update forwards embedded request", + method: "event.update", + params: `{"grant_id":"grant-1","calendar_id":"cal-1","event_id":"event-1","title":"Updated sync","busy":true,"metadata":{"source":"rpc"}}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{ + updateEvent: func(ctx context.Context, grantID, calendarID, eventID string, req *domain.UpdateEventRequest) (*domain.Event, error) { + return &domain.Event{ID: eventID, Title: *req.Title}, nil + }, + }, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + if client.updateGrantID != "grant-1" || client.updateCalendarID != "cal-1" || client.updateEventID != "event-1" { + t.Fatalf("update args = %q %q %q, want grant-1 cal-1 event-1", client.updateGrantID, client.updateCalendarID, client.updateEventID) + } + if client.updateReq.Title == nil || *client.updateReq.Title != title { + t.Fatalf("Title = %v, want %q", client.updateReq.Title, title) + } + if client.updateReq.Busy == nil || *client.updateReq.Busy != busy || client.updateReq.Metadata["source"] != "rpc" { + t.Fatalf("update request = %+v, want busy metadata", client.updateReq) + } + }, + }, + { + name: "event.delete requires event_id", + method: "event.delete", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{}, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.deleteEventID != "" { + t.Fatalf("DeleteEvent called with event %q, want no call", client.deleteEventID) + } + }, + }, + { + name: "event.delete returns deleted", + method: "event.delete", + params: `{"grant_id":"grant-1","calendar_id":"cal-1","event_id":"event-1"}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{ + deleteEvent: func(ctx context.Context, grantID, calendarID, eventID string) error { + return nil + }, + }, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + if client.deleteGrantID != "grant-1" || client.deleteCalendarID != "cal-1" || client.deleteEventID != "event-1" { + t.Fatalf("delete args = %q %q %q, want grant-1 cal-1 event-1", client.deleteGrantID, client.deleteCalendarID, client.deleteEventID) + } + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "event.rsvp requires event_id", + method: "event.rsvp", + params: `{"status":"yes"}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{}, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.rsvpEventID != "" { + t.Fatalf("SendRSVP called with event %q, want no call", client.rsvpEventID) + } + }, + }, + { + name: "event.rsvp forwards embedded request and returns ok", + method: "event.rsvp", + params: `{"grant_id":"grant-1","calendar_id":"cal-1","event_id":"event-1","status":"yes","comment":"See you there"}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{ + sendRSVP: func(ctx context.Context, grantID, calendarID, eventID string, req *domain.SendRSVPRequest) error { + return nil + }, + }, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + if client.rsvpGrantID != "grant-1" || client.rsvpCalendarID != "cal-1" || client.rsvpEventID != "event-1" { + t.Fatalf("rsvp args = %q %q %q, want grant-1 cal-1 event-1", client.rsvpGrantID, client.rsvpCalendarID, client.rsvpEventID) + } + if client.rsvpReq.Status != "yes" || client.rsvpReq.Comment != "See you there" { + t.Fatalf("rsvp request = %+v, want yes with comment", client.rsvpReq) + } + var result eventRSVPResult + unmarshalResult(t, resp, &result) + if !result.OK { + t.Fatal("ok = false, want true") + } + }, + }, + { + name: "client error maps to internal error", + method: "event.create", + params: `{"title":"Sync"}`, + defaultGrant: "default-grant", + client: &fakeCalendarWriteClient{ + createEvent: func(ctx context.Context, grantID, calendarID string, req *domain.CreateEventRequest) (*domain.Event, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeCalendarWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterCalendarWriteHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchCalendarRequest(t, d, tt.method, tt.params) + tt.assert(t, tt.client, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_contact_ext.go b/internal/adapters/rpcserver/handlers_contact_ext.go new file mode 100644 index 0000000..b0139c2 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contact_ext.go @@ -0,0 +1,168 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type contactGroupListParams struct { + GrantID string `json:"grant_id,omitempty"` +} + +type contactGroupListResult struct { + Groups []domain.ContactGroup `json:"groups"` +} + +type contactGroupGetParams struct { + GrantID string `json:"grant_id,omitempty"` + GroupID string `json:"group_id"` +} + +type contactGroupCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateContactGroupRequest +} + +type contactGroupUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + GroupID string `json:"group_id"` + domain.UpdateContactGroupRequest +} + +type contactGroupDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + GroupID string `json:"group_id"` +} + +type contactGetPictureParams struct { + GrantID string `json:"grant_id,omitempty"` + ContactID string `json:"contact_id"` + IncludePicture bool `json:"include_picture,omitempty"` +} + +// RegisterContactExtHandlers registers contact group CRUD and the +// picture-bearing contact read. +func RegisterContactExtHandlers(d *Dispatcher, client ports.ContactClient, defaultGrant string) { + d.Register("contact.group.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGroupListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + groups, err := client.GetContactGroups(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("contact.group.list: %w", err) + } + return contactGroupListResult{Groups: groups}, nil + }) + + d.Register("contact.group.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGroupGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GroupID == "" { + return nil, NewRPCError(InvalidParams, "group_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + group, err := client.GetContactGroup(ctx, grantID, p.GroupID) + if err != nil { + return nil, fmt.Errorf("contact.group.get: %w", err) + } + return group, nil + }) + + d.Register("contact.group.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGroupCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + group, err := client.CreateContactGroup(ctx, grantID, &p.CreateContactGroupRequest) + if err != nil { + return nil, fmt.Errorf("contact.group.create: %w", err) + } + return group, nil + }) + + d.Register("contact.group.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGroupUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GroupID == "" { + return nil, NewRPCError(InvalidParams, "group_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + group, err := client.UpdateContactGroup(ctx, grantID, p.GroupID, &p.UpdateContactGroupRequest) + if err != nil { + return nil, fmt.Errorf("contact.group.update: %w", err) + } + return group, nil + }) + + d.Register("contact.group.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGroupDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GroupID == "" { + return nil, NewRPCError(InvalidParams, "group_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteContactGroup(ctx, grantID, p.GroupID); err != nil { + return nil, fmt.Errorf("contact.group.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("contact.getWithPicture", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGetPictureParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ContactID == "" { + return nil, NewRPCError(InvalidParams, "contact_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + contact, err := client.GetContactWithPicture(ctx, grantID, p.ContactID, p.IncludePicture) + if err != nil { + return nil, fmt.Errorf("contact.getWithPicture: %w", err) + } + return contact, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_contact_ext_test.go b/internal/adapters/rpcserver/handlers_contact_ext_test.go new file mode 100644 index 0000000..be0d6ee --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contact_ext_test.go @@ -0,0 +1,263 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeContactExtClient struct { + ports.ContactClient + + getContactGroups func(context.Context, string) ([]domain.ContactGroup, error) + getContactGroup func(context.Context, string, string) (*domain.ContactGroup, error) + createContactGroup func(context.Context, string, *domain.CreateContactGroupRequest) (*domain.ContactGroup, error) + updateContactGroup func(context.Context, string, string, *domain.UpdateContactGroupRequest) (*domain.ContactGroup, error) + deleteContactGroup func(context.Context, string, string) error + getContactPicture func(context.Context, string, string, bool) (*domain.Contact, error) + + grantIDs []string +} + +func (f *fakeContactExtClient) GetContactGroups(ctx context.Context, grantID string) ([]domain.ContactGroup, error) { + f.grantIDs = append(f.grantIDs, grantID) + if f.getContactGroups == nil { + return nil, errors.New("unexpected GetContactGroups") + } + return f.getContactGroups(ctx, grantID) +} + +func (f *fakeContactExtClient) GetContactGroup(ctx context.Context, grantID, groupID string) (*domain.ContactGroup, error) { + if f.getContactGroup == nil { + return nil, errors.New("unexpected GetContactGroup") + } + return f.getContactGroup(ctx, grantID, groupID) +} + +func (f *fakeContactExtClient) CreateContactGroup(ctx context.Context, grantID string, req *domain.CreateContactGroupRequest) (*domain.ContactGroup, error) { + if f.createContactGroup == nil { + return nil, errors.New("unexpected CreateContactGroup") + } + return f.createContactGroup(ctx, grantID, req) +} + +func (f *fakeContactExtClient) UpdateContactGroup(ctx context.Context, grantID, groupID string, req *domain.UpdateContactGroupRequest) (*domain.ContactGroup, error) { + if f.updateContactGroup == nil { + return nil, errors.New("unexpected UpdateContactGroup") + } + return f.updateContactGroup(ctx, grantID, groupID, req) +} + +func (f *fakeContactExtClient) DeleteContactGroup(ctx context.Context, grantID, groupID string) error { + if f.deleteContactGroup == nil { + return errors.New("unexpected DeleteContactGroup") + } + return f.deleteContactGroup(ctx, grantID, groupID) +} + +func (f *fakeContactExtClient) GetContactWithPicture(ctx context.Context, grantID, contactID string, includePicture bool) (*domain.Contact, error) { + if f.getContactPicture == nil { + return nil, errors.New("unexpected GetContactWithPicture") + } + return f.getContactPicture(ctx, grantID, contactID, includePicture) +} + +func TestRegisterContactExtHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeContactExtClient + assert func(*testing.T, *fakeContactExtClient, rpcTestResponse) + }{ + { + name: "contact.group.list returns groups", + method: "contact.group.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + getContactGroups: func(_ context.Context, grantID string) ([]domain.ContactGroup, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + return []domain.ContactGroup{{ID: "grp-1", Name: "Friends"}}, nil + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result contactGroupListResult + unmarshalResult(t, resp, &result) + if len(result.Groups) != 1 || result.Groups[0].ID != "grp-1" { + t.Fatalf("groups = %+v, want one grp-1", result.Groups) + } + }, + }, + { + name: "contact.group.list without grant errors", + method: "contact.group.list", + params: `{}`, + defaultGrant: "", + client: &fakeContactExtClient{}, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "contact.group.get returns group", + method: "contact.group.get", + params: `{"group_id":"grp-9"}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + getContactGroup: func(_ context.Context, _, groupID string) (*domain.ContactGroup, error) { + if groupID != "grp-9" { + t.Fatalf("groupID = %q, want grp-9", groupID) + } + return &domain.ContactGroup{ID: "grp-9", Name: "Work"}, nil + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var group domain.ContactGroup + unmarshalResult(t, resp, &group) + if group.ID != "grp-9" { + t.Fatalf("group ID = %q, want grp-9", group.ID) + } + }, + }, + { + name: "contact.group.get missing group_id", + method: "contact.group.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{}, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "contact.group.create returns group", + method: "contact.group.create", + params: `{"name":"New Group"}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + createContactGroup: func(_ context.Context, _ string, req *domain.CreateContactGroupRequest) (*domain.ContactGroup, error) { + if req.Name != "New Group" { + t.Fatalf("name = %q, want New Group", req.Name) + } + return &domain.ContactGroup{ID: "grp-new", Name: req.Name}, nil + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var group domain.ContactGroup + unmarshalResult(t, resp, &group) + if group.ID != "grp-new" { + t.Fatalf("group ID = %q, want grp-new", group.ID) + } + }, + }, + { + name: "contact.group.update missing group_id", + method: "contact.group.update", + params: `{"name":"x"}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{}, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "contact.group.delete returns deleted", + method: "contact.group.delete", + params: `{"group_id":"grp-9"}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + deleteContactGroup: func(_ context.Context, _, groupID string) error { + if groupID != "grp-9" { + t.Fatalf("groupID = %q, want grp-9", groupID) + } + return nil + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "contact.getWithPicture passes include flag", + method: "contact.getWithPicture", + params: `{"contact_id":"c-1","include_picture":true}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + getContactPicture: func(_ context.Context, _, contactID string, includePicture bool) (*domain.Contact, error) { + if contactID != "c-1" || !includePicture { + t.Fatalf("args = %q/%v, want c-1/true", contactID, includePicture) + } + return &domain.Contact{ID: "c-1", Picture: "base64data"}, nil + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var contact domain.Contact + unmarshalResult(t, resp, &contact) + if contact.ID != "c-1" || contact.Picture == "" { + t.Fatalf("contact = %+v, want c-1 with picture", contact) + } + }, + }, + { + name: "contact.getWithPicture missing contact_id", + method: "contact.getWithPicture", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{}, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error surfaces as internal error", + method: "contact.group.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeContactExtClient{ + getContactGroups: func(context.Context, string) ([]domain.ContactGroup, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, _ *fakeContactExtClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterContactExtHandlers(d, tt.client, tt.defaultGrant) + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + tt.method + `","params":` + tt.params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + tt.assert(t, tt.client, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_contact_write.go b/internal/adapters/rpcserver/handlers_contact_write.go new file mode 100644 index 0000000..2983359 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contact_write.go @@ -0,0 +1,87 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type contactCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateContactRequest +} + +type contactUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + ContactID string `json:"contact_id"` + domain.UpdateContactRequest +} + +type contactDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + ContactID string `json:"contact_id"` +} + +func RegisterContactWriteHandlers(d *Dispatcher, client ports.ContactClient, defaultGrant string) { + d.Register("contact.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + contact, err := client.CreateContact(ctx, grantID, &p.CreateContactRequest) + if err != nil { + return nil, fmt.Errorf("contact.create: %w", err) + } + return contact, nil + }) + + d.Register("contact.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ContactID == "" { + return nil, NewRPCError(InvalidParams, "contact_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + contact, err := client.UpdateContact(ctx, grantID, p.ContactID, &p.UpdateContactRequest) + if err != nil { + return nil, fmt.Errorf("contact.update: %w", err) + } + return contact, nil + }) + + d.Register("contact.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ContactID == "" { + return nil, NewRPCError(InvalidParams, "contact_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteContact(ctx, grantID, p.ContactID); err != nil { + return nil, fmt.Errorf("contact.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_contact_write_test.go b/internal/adapters/rpcserver/handlers_contact_write_test.go new file mode 100644 index 0000000..848d8e1 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contact_write_test.go @@ -0,0 +1,205 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeContactWriteClient struct { + ports.ContactClient + + createContact func(context.Context, string, *domain.CreateContactRequest) (*domain.Contact, error) + updateContact func(context.Context, string, string, *domain.UpdateContactRequest) (*domain.Contact, error) + deleteContact func(context.Context, string, string) error + + createGrantID string + createReq *domain.CreateContactRequest + updateGrantID string + updateID string + updateReq *domain.UpdateContactRequest + deleteGrantID string + deleteID string +} + +func (f *fakeContactWriteClient) CreateContact(ctx context.Context, grantID string, req *domain.CreateContactRequest) (*domain.Contact, error) { + f.createGrantID = grantID + f.createReq = req + if f.createContact == nil { + return nil, errors.New("unexpected CreateContact") + } + return f.createContact(ctx, grantID, req) +} + +func (f *fakeContactWriteClient) UpdateContact(ctx context.Context, grantID, contactID string, req *domain.UpdateContactRequest) (*domain.Contact, error) { + f.updateGrantID = grantID + f.updateID = contactID + f.updateReq = req + if f.updateContact == nil { + return nil, errors.New("unexpected UpdateContact") + } + return f.updateContact(ctx, grantID, contactID, req) +} + +func (f *fakeContactWriteClient) DeleteContact(ctx context.Context, grantID, contactID string) error { + f.deleteGrantID = grantID + f.deleteID = contactID + if f.deleteContact == nil { + return errors.New("unexpected DeleteContact") + } + return f.deleteContact(ctx, grantID, contactID) +} + +func TestRegisterContactWriteHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeContactWriteClient + assert func(*testing.T, *fakeContactWriteClient, rpcTestResponse) + }{ + { + name: "contact.create forwards request", + method: "contact.create", + params: `{"grant_id":"grant-1","given_name":"Ada","surname":"Lovelace","emails":[{"email":"ada@example.com","type":"work"}]}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{ + createContact: func(ctx context.Context, grantID string, req *domain.CreateContactRequest) (*domain.Contact, error) { + return &domain.Contact{ID: "contact-1", GivenName: req.GivenName}, nil + }, + }, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.createGrantID != "grant-1" { + t.Fatalf("createGrantID = %q, want grant-1", client.createGrantID) + } + if client.createReq == nil || client.createReq.GivenName != "Ada" || client.createReq.Surname != "Lovelace" { + t.Fatalf("createReq = %#v, want Ada Lovelace", client.createReq) + } + if len(client.createReq.Emails) != 1 || client.createReq.Emails[0].Email != "ada@example.com" { + t.Fatalf("createReq.Emails = %#v, want ada@example.com", client.createReq.Emails) + } + + var contact domain.Contact + unmarshalResult(t, resp, &contact) + if contact.ID != "contact-1" || contact.GivenName != "Ada" { + t.Fatalf("contact = %#v, want contact-1 Ada", contact) + } + }, + }, + { + name: "contact.create missing grant returns invalid params", + method: "contact.create", + params: `{"given_name":"Ada"}`, + client: &fakeContactWriteClient{}, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.createReq != nil { + t.Fatalf("createReq = %#v, want nil", client.createReq) + } + }, + }, + { + name: "contact.update forwards embedded fields", + method: "contact.update", + params: `{"contact_id":"contact-1","given_name":"Grace"}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{ + updateContact: func(ctx context.Context, grantID, contactID string, req *domain.UpdateContactRequest) (*domain.Contact, error) { + return &domain.Contact{ID: contactID, GivenName: *req.GivenName}, nil + }, + }, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.updateGrantID != "default-grant" || client.updateID != "contact-1" { + t.Fatalf("update args = %q %q, want default-grant contact-1", client.updateGrantID, client.updateID) + } + if client.updateReq == nil || client.updateReq.GivenName == nil || *client.updateReq.GivenName != "Grace" { + t.Fatalf("updateReq = %#v, want given_name Grace", client.updateReq) + } + }, + }, + { + name: "contact.update missing contact_id returns invalid params", + method: "contact.update", + params: `{"given_name":"Grace"}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{}, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.updateReq != nil { + t.Fatalf("updateReq = %#v, want nil", client.updateReq) + } + }, + }, + { + name: "contact.delete returns deleted", + method: "contact.delete", + params: `{"grant_id":"grant-1","contact_id":"contact-1"}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{ + deleteContact: func(ctx context.Context, grantID, contactID string) error { + return nil + }, + }, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.deleteGrantID != "grant-1" || client.deleteID != "contact-1" { + t.Fatalf("delete args = %q %q, want grant-1 contact-1", client.deleteGrantID, client.deleteID) + } + + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "contact.delete missing contact_id returns invalid params", + method: "contact.delete", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{}, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.deleteID != "" { + t.Fatalf("deleteID = %q, want empty", client.deleteID) + } + }, + }, + { + name: "client error maps to internal error", + method: "contact.create", + params: `{"given_name":"Ada"}`, + defaultGrant: "default-grant", + client: &fakeContactWriteClient{ + createContact: func(ctx context.Context, grantID string, req *domain.CreateContactRequest) (*domain.Contact, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeContactWriteClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.createGrantID != "default-grant" { + t.Fatalf("createGrantID = %q, want default-grant", client.createGrantID) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterContactWriteHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchContactRequest(t, d, tt.method, tt.params) + tt.assert(t, tt.client, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_contacts.go b/internal/adapters/rpcserver/handlers_contacts.go new file mode 100644 index 0000000..159702e --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contacts.go @@ -0,0 +1,76 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type contactListParams struct { + GrantID string `json:"grant_id,omitempty"` + Limit int `json:"limit,omitempty"` + PageToken string `json:"page_token,omitempty"` +} + +type contactListResult struct { + Contacts []domain.Contact `json:"contacts"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` +} + +type contactGetParams struct { + GrantID string `json:"grant_id,omitempty"` + ContactID string `json:"contact_id"` +} + +func RegisterContactHandlers(d *Dispatcher, client ports.ContactClient, defaultGrant string) { + d.Register("contact.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + resp, err := client.GetContactsWithCursor(ctx, grantID, &domain.ContactQueryParams{ + Limit: p.Limit, + PageToken: p.PageToken, + }) + if err != nil { + return nil, fmt.Errorf("contact.list: %w", err) + } + + return contactListResult{ + Contacts: resp.Data, + NextCursor: resp.Pagination.NextCursor, + HasMore: resp.Pagination.HasMore, + }, nil + }) + + d.Register("contact.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p contactGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ContactID == "" { + return nil, NewRPCError(InvalidParams, "contact_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + contact, err := client.GetContact(ctx, grantID, p.ContactID) + if err != nil { + return nil, fmt.Errorf("contact.get: %w", err) + } + return contact, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_contacts_test.go b/internal/adapters/rpcserver/handlers_contacts_test.go new file mode 100644 index 0000000..6ea0a15 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_contacts_test.go @@ -0,0 +1,214 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeContactClient struct { + ports.ContactClient + + getContactsWithCursor func(context.Context, string, *domain.ContactQueryParams) (*domain.ContactListResponse, error) + getContact func(context.Context, string, string) (*domain.Contact, error) + contactGrantIDs []string + contactParams []domain.ContactQueryParams +} + +func (f *fakeContactClient) GetContactsWithCursor(ctx context.Context, grantID string, params *domain.ContactQueryParams) (*domain.ContactListResponse, error) { + f.contactGrantIDs = append(f.contactGrantIDs, grantID) + if params != nil { + f.contactParams = append(f.contactParams, *params) + } + if f.getContactsWithCursor == nil { + return nil, errors.New("unexpected GetContactsWithCursor") + } + return f.getContactsWithCursor(ctx, grantID, params) +} + +func (f *fakeContactClient) GetContact(ctx context.Context, grantID, contactID string) (*domain.Contact, error) { + if f.getContact == nil { + return nil, errors.New("unexpected GetContact") + } + return f.getContact(ctx, grantID, contactID) +} + +func TestRegisterContactHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeContactClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "contact.list returns contacts and next cursor", + method: "contact.list", + params: `{"limit":2}`, + defaultGrant: "default-grant", + client: &fakeContactClient{ + getContactsWithCursor: func(ctx context.Context, grantID string, params *domain.ContactQueryParams) (*domain.ContactListResponse, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + return &domain.ContactListResponse{ + Data: []domain.Contact{ + {ID: "contact-1", GivenName: "Ada"}, + {ID: "contact-2", GivenName: "Grace"}, + }, + Pagination: domain.Pagination{NextCursor: "cursor-2", HasMore: true}, + }, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Contacts []domain.Contact `json:"contacts"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` + } + unmarshalResult(t, resp, &result) + if len(result.Contacts) != 2 || result.Contacts[0].ID != "contact-1" || result.Contacts[1].ID != "contact-2" { + t.Fatalf("contacts = %#v, want contact-1 and contact-2", result.Contacts) + } + if result.NextCursor != "cursor-2" { + t.Fatalf("next_cursor = %q, want %q", result.NextCursor, "cursor-2") + } + if !result.HasMore { + t.Fatal("has_more = false, want true") + } + }, + }, + { + name: "contact.list forwards query params and request grant", + method: "contact.list", + params: `{"grant_id":"request-grant","limit":25,"page_token":"cursor-1"}`, + defaultGrant: "default-grant", + client: &fakeContactClient{ + getContactsWithCursor: func(ctx context.Context, grantID string, params *domain.ContactQueryParams) (*domain.ContactListResponse, error) { + if grantID != "request-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "request-grant") + } + if params.Limit != 25 { + t.Fatalf("Limit = %d, want 25", params.Limit) + } + if params.PageToken != "cursor-1" { + t.Fatalf("PageToken = %q, want %q", params.PageToken, "cursor-1") + } + return &domain.ContactListResponse{}, nil + }, + }, + assert: requireNoRPCError, + }, + { + name: "contact.list missing grant returns invalid params", + method: "contact.list", + params: `{}`, + client: &fakeContactClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "contact.list malformed params returns invalid params", + method: "contact.list", + params: `{"limit":"nope"}`, + client: &fakeContactClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "contact.get with contact_id returns the contact", + method: "contact.get", + params: `{"grant_id":"grant-1","contact_id":"contact-1"}`, + client: &fakeContactClient{ + getContact: func(ctx context.Context, grantID, contactID string) (*domain.Contact, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want %q", grantID, "grant-1") + } + if contactID != "contact-1" { + t.Fatalf("contactID = %q, want %q", contactID, "contact-1") + } + return &domain.Contact{ID: "contact-1", GivenName: "Ada"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var contact domain.Contact + unmarshalResult(t, resp, &contact) + if contact.ID != "contact-1" || contact.GivenName != "Ada" { + t.Fatalf("contact = %#v, want contact-1 Ada", contact) + } + }, + }, + { + name: "contact.get missing contact_id returns invalid params", + method: "contact.get", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "grant-1", + client: &fakeContactClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error maps to internal error", + method: "contact.get", + params: `{"contact_id":"contact-1"}`, + defaultGrant: "default-grant", + client: &fakeContactClient{ + getContact: func(ctx context.Context, grantID, contactID string) (*domain.Contact, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + if contactID != "contact-1" { + t.Fatalf("contactID = %q, want %q", contactID, "contact-1") + } + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterContactHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchContactRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func dispatchContactRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} diff --git a/internal/adapters/rpcserver/handlers_draft.go b/internal/adapters/rpcserver/handlers_draft.go new file mode 100644 index 0000000..b9a7c89 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_draft.go @@ -0,0 +1,61 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/ports" +) + +type draftListParams struct { + GrantID string `json:"grant_id,omitempty"` + Limit int `json:"limit,omitempty"` +} + +type draftGetParams struct { + GrantID string `json:"grant_id,omitempty"` + DraftID string `json:"draft_id"` +} + +func RegisterDraftHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("draft.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + drafts, err := client.GetDrafts(ctx, grantID, p.Limit) + if err != nil { + return nil, fmt.Errorf("draft.list: %w", err) + } + + return map[string]interface{}{"drafts": drafts}, nil + }) + + d.Register("draft.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.DraftID == "" { + return nil, NewRPCError(InvalidParams, "draft_id is required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + draft, err := client.GetDraft(ctx, grantID, p.DraftID) + if err != nil { + return nil, fmt.Errorf("draft.get: %w", err) + } + return draft, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_draft_test.go b/internal/adapters/rpcserver/handlers_draft_test.go new file mode 100644 index 0000000..702e269 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_draft_test.go @@ -0,0 +1,198 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeDraftClient struct { + ports.MessageClient + + getDrafts func(context.Context, string, int) ([]domain.Draft, error) + getDraft func(context.Context, string, string) (*domain.Draft, error) +} + +func (f *fakeDraftClient) GetDrafts(ctx context.Context, grantID string, limit int) ([]domain.Draft, error) { + if f.getDrafts == nil { + return nil, errors.New("unexpected GetDrafts") + } + return f.getDrafts(ctx, grantID, limit) +} + +func (f *fakeDraftClient) GetDraft(ctx context.Context, grantID, draftID string) (*domain.Draft, error) { + if f.getDraft == nil { + return nil, errors.New("unexpected GetDraft") + } + return f.getDraft(ctx, grantID, draftID) +} + +type draftRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +func TestRegisterDraftHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeDraftClient + assert func(*testing.T, draftRPCResponse) + }{ + { + name: "draft.list returns drafts", + method: "draft.list", + params: `{"limit":2}`, + defaultGrant: "default-grant", + client: &fakeDraftClient{ + getDrafts: func(ctx context.Context, grantID string, limit int) ([]domain.Draft, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + if limit != 2 { + t.Fatalf("limit = %d, want 2", limit) + } + return []domain.Draft{ + {ID: "draft-1", Subject: "Hello"}, + {ID: "draft-2", Subject: "World"}, + }, nil + }, + }, + assert: func(t *testing.T, resp draftRPCResponse) { + requireNoDraftRPCError(t, resp) + + var result struct { + Drafts []domain.Draft `json:"drafts"` + } + unmarshalDraftResult(t, resp, &result) + if len(result.Drafts) != 2 || result.Drafts[0].ID != "draft-1" || result.Drafts[1].ID != "draft-2" { + t.Fatalf("drafts = %#v, want draft-1 and draft-2", result.Drafts) + } + }, + }, + { + name: "draft.list missing grant returns invalid params", + method: "draft.list", + params: `{"grant_id":""}`, + client: &fakeDraftClient{}, + assert: func(t *testing.T, resp draftRPCResponse) { + requireDraftRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "draft.get missing draft_id returns invalid params", + method: "draft.get", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "grant-1", + client: &fakeDraftClient{}, + assert: func(t *testing.T, resp draftRPCResponse) { + requireDraftRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "draft.get with draft_id returns the draft", + method: "draft.get", + params: `{"grant_id":"grant-1","draft_id":"draft-1"}`, + client: &fakeDraftClient{ + getDraft: func(ctx context.Context, grantID, draftID string) (*domain.Draft, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want %q", grantID, "grant-1") + } + if draftID != "draft-1" { + t.Fatalf("draftID = %q, want %q", draftID, "draft-1") + } + return &domain.Draft{ID: "draft-1", Subject: "Hello"}, nil + }, + }, + assert: func(t *testing.T, resp draftRPCResponse) { + requireNoDraftRPCError(t, resp) + + var draft domain.Draft + unmarshalDraftResult(t, resp, &draft) + if draft.ID != "draft-1" || draft.Subject != "Hello" { + t.Fatalf("draft = %#v, want draft-1 Hello", draft) + } + }, + }, + { + name: "client error maps to internal error", + method: "draft.get", + params: `{"draft_id":"draft-1"}`, + defaultGrant: "default-grant", + client: &fakeDraftClient{ + getDraft: func(ctx context.Context, grantID, draftID string) (*domain.Draft, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp draftRPCResponse) { + requireDraftRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterDraftHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchDraftRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func dispatchDraftRequest(t *testing.T, d *Dispatcher, method, params string) draftRPCResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp draftRPCResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} + +func requireNoDraftRPCError(t *testing.T, resp draftRPCResponse) { + t.Helper() + + if resp.Error != nil { + t.Fatalf("Error = %#v, want nil", resp.Error) + } +} + +func requireDraftRPCErrorCode(t *testing.T, resp draftRPCResponse, want int) { + t.Helper() + + if resp.Error == nil { + t.Fatal("Error = nil, want RPC error") + } + if resp.Error.Code != want { + t.Fatalf("Error.Code = %d, want %d", resp.Error.Code, want) + } +} + +func unmarshalDraftResult(t *testing.T, resp draftRPCResponse, dest any) { + t.Helper() + + if err := json.Unmarshal(resp.Result, dest); err != nil { + t.Fatalf("unmarshal result: %v", err) + } +} diff --git a/internal/adapters/rpcserver/handlers_email.go b/internal/adapters/rpcserver/handlers_email.go new file mode 100644 index 0000000..9379f25 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email.go @@ -0,0 +1,78 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type emailListParams struct { + GrantID string `json:"grant_id,omitempty"` + Limit int `json:"limit,omitempty"` + PageToken string `json:"page_token,omitempty"` + ReceivedAfter int64 `json:"received_after,omitempty"` +} + +type emailListResult struct { + Messages []domain.Message `json:"messages"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` +} + +type emailGetParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageID string `json:"message_id"` +} + +func RegisterEmailHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + resp, err := client.GetMessagesWithCursor(ctx, grantID, &domain.MessageQueryParams{ + Limit: p.Limit, + PageToken: p.PageToken, + ReceivedAfter: p.ReceivedAfter, + }) + if err != nil { + return nil, fmt.Errorf("email.list: %w", err) + } + + return emailListResult{ + Messages: resp.Data, + NextCursor: resp.Pagination.NextCursor, + HasMore: resp.Pagination.HasMore, + }, nil + }) + + d.Register("email.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + msg, err := client.GetMessage(ctx, grantID, p.MessageID) + if err != nil { + return nil, fmt.Errorf("email.get: %w", err) + } + return msg, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_email_ext.go b/internal/adapters/rpcserver/handlers_email_ext.go new file mode 100644 index 0000000..eab7c2b --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email_ext.go @@ -0,0 +1,505 @@ +package rpcserver + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +// maxAttachmentDownloadBytes caps email.attachment.download: the bytes are +// base64-encoded into one JSON-RPC response, so the whole attachment is held in +// memory. 30 MiB comfortably covers provider attachment limits. +const maxAttachmentDownloadBytes = 30 << 20 + +type emailGrantParams struct { + GrantID string `json:"grant_id,omitempty"` +} + +type folderListResult struct { + Folders []domain.Folder `json:"folders"` +} + +type folderGetParams struct { + GrantID string `json:"grant_id,omitempty"` + FolderID string `json:"folder_id"` +} + +type folderCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateFolderRequest +} + +type folderUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + FolderID string `json:"folder_id"` + domain.UpdateFolderRequest +} + +type folderDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + FolderID string `json:"folder_id"` +} + +type attachmentListParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageID string `json:"message_id"` +} + +type attachmentListResult struct { + Attachments []domain.Attachment `json:"attachments"` +} + +type attachmentGetParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageID string `json:"message_id"` + AttachmentID string `json:"attachment_id"` +} + +type attachmentDownloadResult struct { + Content string `json:"content"` // base64-encoded attachment bytes + Size int `json:"size"` +} + +type signatureListResult struct { + Signatures []domain.Signature `json:"signatures"` +} + +type signatureGetParams struct { + GrantID string `json:"grant_id,omitempty"` + SignatureID string `json:"signature_id"` +} + +type signatureCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateSignatureRequest +} + +type signatureUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + SignatureID string `json:"signature_id"` + domain.UpdateSignatureRequest +} + +type signatureDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + SignatureID string `json:"signature_id"` +} + +type scheduledListResult struct { + Scheduled []domain.ScheduledMessage `json:"scheduled"` +} + +type scheduledGetParams struct { + GrantID string `json:"grant_id,omitempty"` + ScheduleID string `json:"schedule_id"` +} + +// cleanParams uses message_ids (plural) for the RPC contract; the embedded +// domain request tags the same slice message_id, which is confusing over the +// wire, so the IDs are accepted explicitly and copied into the request. +type cleanParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageIDs []string `json:"message_ids"` + IgnoreLinks *bool `json:"ignore_links,omitempty"` + IgnoreImages *bool `json:"ignore_images,omitempty"` + IgnoreTables *bool `json:"ignore_tables,omitempty"` + ImagesAsMarkdown *bool `json:"images_as_markdown,omitempty"` + RemoveConclusionPhrases *bool `json:"remove_conclusion_phrases,omitempty"` +} + +type cleanResult struct { + Messages []domain.CleanedMessage `json:"messages"` +} + +type cancelledResult struct { + Cancelled bool `json:"cancelled"` +} + +// RegisterEmailExtHandlers registers folder, attachment, signature, scheduled +// message, and message-clean methods. +func RegisterEmailExtHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + registerEmailFolderHandlers(d, client, defaultGrant) + registerEmailAttachmentHandlers(d, client, defaultGrant) + registerEmailSignatureHandlers(d, client, defaultGrant) + registerEmailScheduledHandlers(d, client, defaultGrant) + + d.Register("email.clean", func(ctx context.Context, params json.RawMessage) (any, error) { + var p cleanParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + if len(p.MessageIDs) == 0 { + return nil, NewRPCError(InvalidParams, "message_ids required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + messages, err := client.CleanMessages(ctx, grantID, &domain.CleanMessagesRequest{ + MessageIDs: p.MessageIDs, + IgnoreLinks: p.IgnoreLinks, + IgnoreImages: p.IgnoreImages, + IgnoreTables: p.IgnoreTables, + ImagesAsMarkdown: p.ImagesAsMarkdown, + RemoveConclusionPhrases: p.RemoveConclusionPhrases, + }) + if err != nil { + return nil, fmt.Errorf("email.clean: %w", err) + } + return cleanResult{Messages: messages}, nil + }) +} + +func registerEmailFolderHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.folder.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailGrantParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + folders, err := client.GetFolders(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("email.folder.list: %w", err) + } + return folderListResult{Folders: folders}, nil + }) + + d.Register("email.folder.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p folderGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.FolderID == "" { + return nil, NewRPCError(InvalidParams, "folder_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + folder, err := client.GetFolder(ctx, grantID, p.FolderID) + if err != nil { + return nil, fmt.Errorf("email.folder.get: %w", err) + } + return folder, nil + }) + + d.Register("email.folder.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p folderCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + folder, err := client.CreateFolder(ctx, grantID, &p.CreateFolderRequest) + if err != nil { + return nil, fmt.Errorf("email.folder.create: %w", err) + } + return folder, nil + }) + + d.Register("email.folder.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p folderUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.FolderID == "" { + return nil, NewRPCError(InvalidParams, "folder_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + folder, err := client.UpdateFolder(ctx, grantID, p.FolderID, &p.UpdateFolderRequest) + if err != nil { + return nil, fmt.Errorf("email.folder.update: %w", err) + } + return folder, nil + }) + + d.Register("email.folder.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p folderDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.FolderID == "" { + return nil, NewRPCError(InvalidParams, "folder_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteFolder(ctx, grantID, p.FolderID); err != nil { + return nil, fmt.Errorf("email.folder.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} + +func registerEmailAttachmentHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.attachment.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p attachmentListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + attachments, err := client.ListAttachments(ctx, grantID, p.MessageID) + if err != nil { + return nil, fmt.Errorf("email.attachment.list: %w", err) + } + return attachmentListResult{Attachments: attachments}, nil + }) + + d.Register("email.attachment.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p attachmentGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + if p.AttachmentID == "" { + return nil, NewRPCError(InvalidParams, "attachment_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + attachment, err := client.GetAttachment(ctx, grantID, p.MessageID, p.AttachmentID) + if err != nil { + return nil, fmt.Errorf("email.attachment.get: %w", err) + } + return attachment, nil + }) + + d.Register("email.attachment.download", func(ctx context.Context, params json.RawMessage) (any, error) { + var p attachmentGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + if p.AttachmentID == "" { + return nil, NewRPCError(InvalidParams, "attachment_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + body, err := client.DownloadAttachment(ctx, grantID, p.MessageID, p.AttachmentID) + if err != nil { + return nil, fmt.Errorf("email.attachment.download: %w", err) + } + defer func() { _ = body.Close() }() + + // Cap the in-memory buffer: the whole attachment is base64-encoded into a + // single JSON-RPC response, so an oversized attachment would balloon heap. + data, err := io.ReadAll(io.LimitReader(body, maxAttachmentDownloadBytes+1)) + if err != nil { + return nil, fmt.Errorf("email.attachment.download: read body: %w", err) + } + if len(data) > maxAttachmentDownloadBytes { + return nil, NewRPCError(InvalidParams, "attachment exceeds maximum download size", nil) + } + return attachmentDownloadResult{ + Content: base64.StdEncoding.EncodeToString(data), + Size: len(data), + }, nil + }) +} + +func registerEmailSignatureHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.signature.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailGrantParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + signatures, err := client.GetSignatures(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("email.signature.list: %w", err) + } + return signatureListResult{Signatures: signatures}, nil + }) + + d.Register("email.signature.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p signatureGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.SignatureID == "" { + return nil, NewRPCError(InvalidParams, "signature_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + signature, err := client.GetSignature(ctx, grantID, p.SignatureID) + if err != nil { + return nil, fmt.Errorf("email.signature.get: %w", err) + } + return signature, nil + }) + + d.Register("email.signature.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p signatureCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + signature, err := client.CreateSignature(ctx, grantID, &p.CreateSignatureRequest) + if err != nil { + return nil, fmt.Errorf("email.signature.create: %w", err) + } + return signature, nil + }) + + d.Register("email.signature.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p signatureUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.SignatureID == "" { + return nil, NewRPCError(InvalidParams, "signature_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + signature, err := client.UpdateSignature(ctx, grantID, p.SignatureID, &p.UpdateSignatureRequest) + if err != nil { + return nil, fmt.Errorf("email.signature.update: %w", err) + } + return signature, nil + }) + + d.Register("email.signature.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p signatureDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.SignatureID == "" { + return nil, NewRPCError(InvalidParams, "signature_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteSignature(ctx, grantID, p.SignatureID); err != nil { + return nil, fmt.Errorf("email.signature.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} + +func registerEmailScheduledHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.scheduled.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailGrantParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + scheduled, err := client.ListScheduledMessages(ctx, grantID) + if err != nil { + return nil, fmt.Errorf("email.scheduled.list: %w", err) + } + return scheduledListResult{Scheduled: scheduled}, nil + }) + + d.Register("email.scheduled.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p scheduledGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ScheduleID == "" { + return nil, NewRPCError(InvalidParams, "schedule_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + scheduled, err := client.GetScheduledMessage(ctx, grantID, p.ScheduleID) + if err != nil { + return nil, fmt.Errorf("email.scheduled.get: %w", err) + } + return scheduled, nil + }) + + d.Register("email.scheduled.cancel", func(ctx context.Context, params json.RawMessage) (any, error) { + var p scheduledGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ScheduleID == "" { + return nil, NewRPCError(InvalidParams, "schedule_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.CancelScheduledMessage(ctx, grantID, p.ScheduleID); err != nil { + return nil, fmt.Errorf("email.scheduled.cancel: %w", err) + } + return cancelledResult{Cancelled: true}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_email_ext_test.go b/internal/adapters/rpcserver/handlers_email_ext_test.go new file mode 100644 index 0000000..778c515 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email_ext_test.go @@ -0,0 +1,557 @@ +package rpcserver + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "strings" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeEmailExtClient struct { + ports.MessageClient + + getFolders func(context.Context, string) ([]domain.Folder, error) + getFolder func(context.Context, string, string) (*domain.Folder, error) + createFolder func(context.Context, string, *domain.CreateFolderRequest) (*domain.Folder, error) + updateFolder func(context.Context, string, string, *domain.UpdateFolderRequest) (*domain.Folder, error) + deleteFolder func(context.Context, string, string) error + listAttachments func(context.Context, string, string) ([]domain.Attachment, error) + getAttachment func(context.Context, string, string, string) (*domain.Attachment, error) + downloadAttachment func(context.Context, string, string, string) (io.ReadCloser, error) + getSignatures func(context.Context, string) ([]domain.Signature, error) + getSignature func(context.Context, string, string) (*domain.Signature, error) + createSignature func(context.Context, string, *domain.CreateSignatureRequest) (*domain.Signature, error) + updateSignature func(context.Context, string, string, *domain.UpdateSignatureRequest) (*domain.Signature, error) + deleteSignature func(context.Context, string, string) error + listScheduled func(context.Context, string) ([]domain.ScheduledMessage, error) + getScheduled func(context.Context, string, string) (*domain.ScheduledMessage, error) + cancelScheduled func(context.Context, string, string) error + cleanMessages func(context.Context, string, *domain.CleanMessagesRequest) ([]domain.CleanedMessage, error) +} + +func (f *fakeEmailExtClient) GetFolders(ctx context.Context, grantID string) ([]domain.Folder, error) { + if f.getFolders == nil { + return nil, errors.New("unexpected GetFolders") + } + return f.getFolders(ctx, grantID) +} + +func (f *fakeEmailExtClient) GetFolder(ctx context.Context, grantID, folderID string) (*domain.Folder, error) { + if f.getFolder == nil { + return nil, errors.New("unexpected GetFolder") + } + return f.getFolder(ctx, grantID, folderID) +} + +func (f *fakeEmailExtClient) CreateFolder(ctx context.Context, grantID string, req *domain.CreateFolderRequest) (*domain.Folder, error) { + if f.createFolder == nil { + return nil, errors.New("unexpected CreateFolder") + } + return f.createFolder(ctx, grantID, req) +} + +func (f *fakeEmailExtClient) UpdateFolder(ctx context.Context, grantID, folderID string, req *domain.UpdateFolderRequest) (*domain.Folder, error) { + if f.updateFolder == nil { + return nil, errors.New("unexpected UpdateFolder") + } + return f.updateFolder(ctx, grantID, folderID, req) +} + +func (f *fakeEmailExtClient) DeleteFolder(ctx context.Context, grantID, folderID string) error { + if f.deleteFolder == nil { + return errors.New("unexpected DeleteFolder") + } + return f.deleteFolder(ctx, grantID, folderID) +} + +func (f *fakeEmailExtClient) ListAttachments(ctx context.Context, grantID, messageID string) ([]domain.Attachment, error) { + if f.listAttachments == nil { + return nil, errors.New("unexpected ListAttachments") + } + return f.listAttachments(ctx, grantID, messageID) +} + +func (f *fakeEmailExtClient) GetAttachment(ctx context.Context, grantID, messageID, attachmentID string) (*domain.Attachment, error) { + if f.getAttachment == nil { + return nil, errors.New("unexpected GetAttachment") + } + return f.getAttachment(ctx, grantID, messageID, attachmentID) +} + +func (f *fakeEmailExtClient) DownloadAttachment(ctx context.Context, grantID, messageID, attachmentID string) (io.ReadCloser, error) { + if f.downloadAttachment == nil { + return nil, errors.New("unexpected DownloadAttachment") + } + return f.downloadAttachment(ctx, grantID, messageID, attachmentID) +} + +func (f *fakeEmailExtClient) GetSignatures(ctx context.Context, grantID string) ([]domain.Signature, error) { + if f.getSignatures == nil { + return nil, errors.New("unexpected GetSignatures") + } + return f.getSignatures(ctx, grantID) +} + +func (f *fakeEmailExtClient) GetSignature(ctx context.Context, grantID, signatureID string) (*domain.Signature, error) { + if f.getSignature == nil { + return nil, errors.New("unexpected GetSignature") + } + return f.getSignature(ctx, grantID, signatureID) +} + +func (f *fakeEmailExtClient) CreateSignature(ctx context.Context, grantID string, req *domain.CreateSignatureRequest) (*domain.Signature, error) { + if f.createSignature == nil { + return nil, errors.New("unexpected CreateSignature") + } + return f.createSignature(ctx, grantID, req) +} + +func (f *fakeEmailExtClient) UpdateSignature(ctx context.Context, grantID, signatureID string, req *domain.UpdateSignatureRequest) (*domain.Signature, error) { + if f.updateSignature == nil { + return nil, errors.New("unexpected UpdateSignature") + } + return f.updateSignature(ctx, grantID, signatureID, req) +} + +func (f *fakeEmailExtClient) DeleteSignature(ctx context.Context, grantID, signatureID string) error { + if f.deleteSignature == nil { + return errors.New("unexpected DeleteSignature") + } + return f.deleteSignature(ctx, grantID, signatureID) +} + +func (f *fakeEmailExtClient) ListScheduledMessages(ctx context.Context, grantID string) ([]domain.ScheduledMessage, error) { + if f.listScheduled == nil { + return nil, errors.New("unexpected ListScheduledMessages") + } + return f.listScheduled(ctx, grantID) +} + +func (f *fakeEmailExtClient) GetScheduledMessage(ctx context.Context, grantID, scheduleID string) (*domain.ScheduledMessage, error) { + if f.getScheduled == nil { + return nil, errors.New("unexpected GetScheduledMessage") + } + return f.getScheduled(ctx, grantID, scheduleID) +} + +func (f *fakeEmailExtClient) CancelScheduledMessage(ctx context.Context, grantID, scheduleID string) error { + if f.cancelScheduled == nil { + return errors.New("unexpected CancelScheduledMessage") + } + return f.cancelScheduled(ctx, grantID, scheduleID) +} + +func (f *fakeEmailExtClient) CleanMessages(ctx context.Context, grantID string, req *domain.CleanMessagesRequest) ([]domain.CleanedMessage, error) { + if f.cleanMessages == nil { + return nil, errors.New("unexpected CleanMessages") + } + return f.cleanMessages(ctx, grantID, req) +} + +type trackingReadCloser struct { + io.Reader + closed bool +} + +func (t *trackingReadCloser) Close() error { + t.closed = true + return nil +} + +func TestRegisterEmailExtHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeEmailExtClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "email.folder.list returns folders", + method: "email.folder.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + getFolders: func(_ context.Context, grantID string) ([]domain.Folder, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + return []domain.Folder{{ID: "fld-1", Name: "Inbox"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result folderListResult + unmarshalResult(t, resp, &result) + if len(result.Folders) != 1 || result.Folders[0].ID != "fld-1" { + t.Fatalf("folders = %+v, want one fld-1", result.Folders) + } + }, + }, + { + name: "email.folder.get missing folder_id", + method: "email.folder.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.folder.create returns folder", + method: "email.folder.create", + params: `{"name":"Receipts"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + createFolder: func(_ context.Context, _ string, req *domain.CreateFolderRequest) (*domain.Folder, error) { + if req.Name != "Receipts" { + t.Fatalf("name = %q, want Receipts", req.Name) + } + return &domain.Folder{ID: "fld-new", Name: req.Name}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var folder domain.Folder + unmarshalResult(t, resp, &folder) + if folder.ID != "fld-new" { + t.Fatalf("folder ID = %q, want fld-new", folder.ID) + } + }, + }, + { + name: "email.folder.update missing folder_id", + method: "email.folder.update", + params: `{"name":"x"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.folder.delete returns deleted", + method: "email.folder.delete", + params: `{"folder_id":"fld-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + deleteFolder: func(_ context.Context, _, folderID string) error { + if folderID != "fld-1" { + t.Fatalf("folderID = %q, want fld-1", folderID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "email.attachment.list missing message_id", + method: "email.attachment.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.attachment.list returns attachments", + method: "email.attachment.list", + params: `{"message_id":"msg-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + listAttachments: func(_ context.Context, _, messageID string) ([]domain.Attachment, error) { + if messageID != "msg-1" { + t.Fatalf("messageID = %q, want msg-1", messageID) + } + return []domain.Attachment{{ID: "att-1"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result attachmentListResult + unmarshalResult(t, resp, &result) + if len(result.Attachments) != 1 || result.Attachments[0].ID != "att-1" { + t.Fatalf("attachments = %+v, want one att-1", result.Attachments) + } + }, + }, + { + name: "email.attachment.get missing attachment_id", + method: "email.attachment.get", + params: `{"message_id":"msg-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.attachment.download base64-encodes bytes", + method: "email.attachment.download", + params: `{"message_id":"msg-1","attachment_id":"att-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + downloadAttachment: func(_ context.Context, _, messageID, attachmentID string) (io.ReadCloser, error) { + if messageID != "msg-1" || attachmentID != "att-1" { + t.Fatalf("args = %q/%q, want msg-1/att-1", messageID, attachmentID) + } + return io.NopCloser(strings.NewReader("hello world")), nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result attachmentDownloadResult + unmarshalResult(t, resp, &result) + want := base64.StdEncoding.EncodeToString([]byte("hello world")) + if result.Content != want { + t.Fatalf("content = %q, want %q", result.Content, want) + } + if result.Size != len("hello world") { + t.Fatalf("size = %d, want %d", result.Size, len("hello world")) + } + }, + }, + { + name: "email.signature.list returns signatures", + method: "email.signature.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + getSignatures: func(context.Context, string) ([]domain.Signature, error) { + return []domain.Signature{{ID: "sig-1", Name: "Default"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result signatureListResult + unmarshalResult(t, resp, &result) + if len(result.Signatures) != 1 || result.Signatures[0].ID != "sig-1" { + t.Fatalf("signatures = %+v, want one sig-1", result.Signatures) + } + }, + }, + { + name: "email.signature.get missing signature_id", + method: "email.signature.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.signature.create returns signature", + method: "email.signature.create", + params: `{"name":"Sig"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + createSignature: func(context.Context, string, *domain.CreateSignatureRequest) (*domain.Signature, error) { + return &domain.Signature{ID: "sig-new"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var sig domain.Signature + unmarshalResult(t, resp, &sig) + if sig.ID != "sig-new" { + t.Fatalf("signature ID = %q, want sig-new", sig.ID) + } + }, + }, + { + name: "email.signature.update missing signature_id", + method: "email.signature.update", + params: `{"name":"x"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.signature.delete returns deleted", + method: "email.signature.delete", + params: `{"signature_id":"sig-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + deleteSignature: func(_ context.Context, _, signatureID string) error { + if signatureID != "sig-1" { + t.Fatalf("signatureID = %q, want sig-1", signatureID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "email.scheduled.list returns scheduled", + method: "email.scheduled.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + listScheduled: func(context.Context, string) ([]domain.ScheduledMessage, error) { + return []domain.ScheduledMessage{{ScheduleID: "sch-1", Status: "pending"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result scheduledListResult + unmarshalResult(t, resp, &result) + if len(result.Scheduled) != 1 || result.Scheduled[0].ScheduleID != "sch-1" { + t.Fatalf("scheduled = %+v, want one sch-1", result.Scheduled) + } + }, + }, + { + name: "email.scheduled.get missing schedule_id", + method: "email.scheduled.get", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.scheduled.cancel returns canceled", + method: "email.scheduled.cancel", + params: `{"schedule_id":"sch-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + cancelScheduled: func(_ context.Context, _, scheduleID string) error { + if scheduleID != "sch-1" { + t.Fatalf("scheduleID = %q, want sch-1", scheduleID) + } + return nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result cancelledResult + unmarshalResult(t, resp, &result) + if !result.Cancelled { + t.Fatal("cancelled = false, want true") + } + }, + }, + { + name: "email.clean maps message_ids and returns cleaned messages", + method: "email.clean", + params: `{"message_ids":["msg-1","msg-2"]}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + cleanMessages: func(_ context.Context, _ string, req *domain.CleanMessagesRequest) ([]domain.CleanedMessage, error) { + if len(req.MessageIDs) != 2 || req.MessageIDs[0] != "msg-1" { + t.Fatalf("req.MessageIDs = %v, want [msg-1 msg-2]", req.MessageIDs) + } + return []domain.CleanedMessage{{ID: "msg-1", Conversation: "clean text"}}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result cleanResult + unmarshalResult(t, resp, &result) + if len(result.Messages) != 1 || result.Messages[0].ID != "msg-1" { + t.Fatalf("messages = %+v, want one msg-1", result.Messages) + } + }, + }, + { + name: "email.clean without message_ids errors", + method: "email.clean", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "email.attachment.download rejects oversized attachment", + method: "email.attachment.download", + params: `{"message_id":"msg-1","attachment_id":"att-big"}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + downloadAttachment: func(context.Context, string, string, string) (io.ReadCloser, error) { + // One byte past the cap. + return io.NopCloser(strings.NewReader(strings.Repeat("a", maxAttachmentDownloadBytes+1))), nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + { + name: "client error surfaces as internal error", + method: "email.folder.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeEmailExtClient{ + getFolders: func(context.Context, string) ([]domain.Folder, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InternalError) }, + }, + { + name: "missing default grant errors", + method: "email.folder.list", + params: `{}`, + defaultGrant: "", + client: &fakeEmailExtClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterEmailExtHandlers(d, tt.client, tt.defaultGrant) + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + tt.method + `","params":` + tt.params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + tt.assert(t, resp) + }) + } +} + +func TestRegisterEmailExtHandlers_AttachmentDownloadClosesBody(t *testing.T) { + tracker := &trackingReadCloser{Reader: strings.NewReader("data")} + client := &fakeEmailExtClient{ + downloadAttachment: func(context.Context, string, string, string) (io.ReadCloser, error) { + return tracker, nil + }, + } + + d := NewDispatcher() + RegisterEmailExtHandlers(d, client, "default-grant") + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"email.attachment.download","params":{"message_id":"msg-1","attachment_id":"att-1"}}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + requireNoRPCError(t, resp) + if !tracker.closed { + t.Fatal("download body was not closed") + } +} diff --git a/internal/adapters/rpcserver/handlers_email_test.go b/internal/adapters/rpcserver/handlers_email_test.go new file mode 100644 index 0000000..6b5dd16 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email_test.go @@ -0,0 +1,236 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeEmailClient struct { + ports.MessageClient + + getMessagesWithCursor func(context.Context, string, *domain.MessageQueryParams) (*domain.MessageListResponse, error) + getMessage func(context.Context, string, string) (*domain.Message, error) +} + +func (f *fakeEmailClient) GetMessagesWithCursor(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + if f.getMessagesWithCursor == nil { + return nil, errors.New("unexpected GetMessagesWithCursor") + } + return f.getMessagesWithCursor(ctx, grantID, params) +} + +func (f *fakeEmailClient) GetMessage(ctx context.Context, grantID, messageID string) (*domain.Message, error) { + if f.getMessage == nil { + return nil, errors.New("unexpected GetMessage") + } + return f.getMessage(ctx, grantID, messageID) +} + +type rpcTestResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +func TestRegisterEmailHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeEmailClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "email.list returns messages and next cursor", + method: "email.list", + params: `{"limit":2}`, + defaultGrant: "default-grant", + client: &fakeEmailClient{ + getMessagesWithCursor: func(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + return &domain.MessageListResponse{ + Data: []domain.Message{ + {ID: "msg-1", Subject: "Hello"}, + {ID: "msg-2", Subject: "World"}, + }, + Pagination: domain.Pagination{NextCursor: "cursor-2", HasMore: true}, + }, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Messages []domain.Message `json:"messages"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` + } + unmarshalResult(t, resp, &result) + if len(result.Messages) != 2 || result.Messages[0].ID != "msg-1" || result.Messages[1].ID != "msg-2" { + t.Fatalf("messages = %#v, want msg-1 and msg-2", result.Messages) + } + if result.NextCursor != "cursor-2" { + t.Fatalf("next_cursor = %q, want %q", result.NextCursor, "cursor-2") + } + if !result.HasMore { + t.Fatal("has_more = false, want true") + } + }, + }, + { + name: "email.list forwards query params and request grant", + method: "email.list", + params: `{"grant_id":"request-grant","limit":25,"page_token":"cursor-1","received_after":1710000000}`, + defaultGrant: "default-grant", + client: &fakeEmailClient{ + getMessagesWithCursor: func(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + if grantID != "request-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "request-grant") + } + if params.Limit != 25 { + t.Fatalf("Limit = %d, want 25", params.Limit) + } + if params.PageToken != "cursor-1" { + t.Fatalf("PageToken = %q, want %q", params.PageToken, "cursor-1") + } + if params.ReceivedAfter != 1710000000 { + t.Fatalf("ReceivedAfter = %d, want 1710000000", params.ReceivedAfter) + } + return &domain.MessageListResponse{}, nil + }, + }, + assert: requireNoRPCError, + }, + { + name: "email.list missing grant returns invalid params", + method: "email.list", + params: `{}`, + client: &fakeEmailClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "email.get with message_id returns the message", + method: "email.get", + params: `{"grant_id":"grant-1","message_id":"msg-1"}`, + client: &fakeEmailClient{ + getMessage: func(ctx context.Context, grantID, messageID string) (*domain.Message, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want %q", grantID, "grant-1") + } + if messageID != "msg-1" { + t.Fatalf("messageID = %q, want %q", messageID, "msg-1") + } + return &domain.Message{ID: "msg-1", Subject: "Hello"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var msg domain.Message + unmarshalResult(t, resp, &msg) + if msg.ID != "msg-1" || msg.Subject != "Hello" { + t.Fatalf("message = %#v, want msg-1 Hello", msg) + } + }, + }, + { + name: "email.get missing message_id returns invalid params", + method: "email.get", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "grant-1", + client: &fakeEmailClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error maps to internal error", + method: "email.get", + params: `{"message_id":"msg-1"}`, + defaultGrant: "default-grant", + client: &fakeEmailClient{ + getMessage: func(ctx context.Context, grantID, messageID string) (*domain.Message, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + if messageID != "msg-1" { + t.Fatalf("messageID = %q, want %q", messageID, "msg-1") + } + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterEmailHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchEmailRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func dispatchEmailRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} + +func requireNoRPCError(t *testing.T, resp rpcTestResponse) { + t.Helper() + + if resp.Error != nil { + t.Fatalf("Error = %#v, want nil", resp.Error) + } +} + +func requireRPCErrorCode(t *testing.T, resp rpcTestResponse, want int) { + t.Helper() + + if resp.Error == nil { + t.Fatal("Error = nil, want RPC error") + } + if resp.Error.Code != want { + t.Fatalf("Error.Code = %d, want %d", resp.Error.Code, want) + } +} + +func unmarshalResult(t *testing.T, resp rpcTestResponse, dest any) { + t.Helper() + + if err := json.Unmarshal(resp.Result, dest); err != nil { + t.Fatalf("unmarshal result: %v", err) + } +} diff --git a/internal/adapters/rpcserver/handlers_email_write.go b/internal/adapters/rpcserver/handlers_email_write.go new file mode 100644 index 0000000..2123c9b --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email_write.go @@ -0,0 +1,189 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type emailSendParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.SendMessageRequest +} + +type emailUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageID string `json:"message_id"` + domain.UpdateMessageRequest +} + +type emailDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + MessageID string `json:"message_id"` +} + +type draftCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateDraftRequest +} + +type draftUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + DraftID string `json:"draft_id"` + domain.CreateDraftRequest +} + +type draftDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + DraftID string `json:"draft_id"` +} + +type draftSendParams struct { + GrantID string `json:"grant_id,omitempty"` + DraftID string `json:"draft_id"` + domain.SendDraftRequest +} + +func RegisterEmailWriteHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("email.send", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailSendParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + msg, err := client.SendMessage(ctx, grantID, &p.SendMessageRequest) + if err != nil { + return nil, fmt.Errorf("email.send: %w", err) + } + return msg, nil + }) + + d.Register("email.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + msg, err := client.UpdateMessage(ctx, grantID, p.MessageID, &p.UpdateMessageRequest) + if err != nil { + return nil, fmt.Errorf("email.update: %w", err) + } + return msg, nil + }) + + d.Register("email.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p emailDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.MessageID == "" { + return nil, NewRPCError(InvalidParams, "message_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteMessage(ctx, grantID, p.MessageID); err != nil { + return nil, fmt.Errorf("email.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("draft.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + draft, err := client.CreateDraft(ctx, grantID, &p.CreateDraftRequest) + if err != nil { + return nil, fmt.Errorf("draft.create: %w", err) + } + return draft, nil + }) + + d.Register("draft.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.DraftID == "" { + return nil, NewRPCError(InvalidParams, "draft_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + draft, err := client.UpdateDraft(ctx, grantID, p.DraftID, &p.CreateDraftRequest) + if err != nil { + return nil, fmt.Errorf("draft.update: %w", err) + } + return draft, nil + }) + + d.Register("draft.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.DraftID == "" { + return nil, NewRPCError(InvalidParams, "draft_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteDraft(ctx, grantID, p.DraftID); err != nil { + return nil, fmt.Errorf("draft.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("draft.send", func(ctx context.Context, params json.RawMessage) (any, error) { + var p draftSendParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.DraftID == "" { + return nil, NewRPCError(InvalidParams, "draft_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + msg, err := client.SendDraft(ctx, grantID, p.DraftID, &p.SendDraftRequest) + if err != nil { + return nil, fmt.Errorf("draft.send: %w", err) + } + return msg, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_email_write_test.go b/internal/adapters/rpcserver/handlers_email_write_test.go new file mode 100644 index 0000000..54425eb --- /dev/null +++ b/internal/adapters/rpcserver/handlers_email_write_test.go @@ -0,0 +1,287 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeEmailWriteClient struct { + ports.MessageClient + + err error + + sendMessageResult *domain.Message + updateMessageResult *domain.Message + createDraftResult *domain.Draft + updateDraftResult *domain.Draft + sendDraftResult *domain.Message + + sendGrantID string + sendRequest *domain.SendMessageRequest + updateGrantID string + updateMessageID string + updateRequest *domain.UpdateMessageRequest + deleteGrantID string + deleteMessageID string + createDraftGrantID string + createDraftRequest *domain.CreateDraftRequest + updateDraftGrantID string + updateDraftID string + updateDraftRequest *domain.CreateDraftRequest + deleteDraftGrantID string + deleteDraftID string + sendDraftGrantID string + sendDraftID string + sendDraftRequest *domain.SendDraftRequest +} + +func (f *fakeEmailWriteClient) SendMessage(ctx context.Context, grantID string, req *domain.SendMessageRequest) (*domain.Message, error) { + f.sendGrantID = grantID + f.sendRequest = req + return f.sendMessageResult, f.err +} + +func (f *fakeEmailWriteClient) UpdateMessage(ctx context.Context, grantID, messageID string, req *domain.UpdateMessageRequest) (*domain.Message, error) { + f.updateGrantID = grantID + f.updateMessageID = messageID + f.updateRequest = req + return f.updateMessageResult, f.err +} + +func (f *fakeEmailWriteClient) DeleteMessage(ctx context.Context, grantID, messageID string) error { + f.deleteGrantID = grantID + f.deleteMessageID = messageID + return f.err +} + +func (f *fakeEmailWriteClient) CreateDraft(ctx context.Context, grantID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { + f.createDraftGrantID = grantID + f.createDraftRequest = req + return f.createDraftResult, f.err +} + +func (f *fakeEmailWriteClient) UpdateDraft(ctx context.Context, grantID, draftID string, req *domain.CreateDraftRequest) (*domain.Draft, error) { + f.updateDraftGrantID = grantID + f.updateDraftID = draftID + f.updateDraftRequest = req + return f.updateDraftResult, f.err +} + +func (f *fakeEmailWriteClient) DeleteDraft(ctx context.Context, grantID, draftID string) error { + f.deleteDraftGrantID = grantID + f.deleteDraftID = draftID + return f.err +} + +func (f *fakeEmailWriteClient) SendDraft(ctx context.Context, grantID, draftID string, req *domain.SendDraftRequest) (*domain.Message, error) { + f.sendDraftGrantID = grantID + f.sendDraftID = draftID + f.sendDraftRequest = req + return f.sendDraftResult, f.err +} + +func TestRegisterEmailWriteHandlers(t *testing.T) { + unread := false + starred := true + + tests := []struct { + name string + method string + params string + client *fakeEmailWriteClient + assert func(*testing.T, *fakeEmailWriteClient, rpcTestResponse) + }{ + { + name: "email.send forwards request and returns message", + method: "email.send", + params: `{"grant_id":"grant-1","subject":"Hello","body":"World","to":[{"email":"ada@example.com","name":"Ada"}]}`, + client: &fakeEmailWriteClient{ + sendMessageResult: &domain.Message{ID: "msg-1", Subject: "Hello"}, + }, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.sendGrantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", client.sendGrantID) + } + if client.sendRequest == nil { + t.Fatal("sendRequest = nil, want request") + } + if client.sendRequest.Subject != "Hello" { + t.Fatalf("Subject = %q, want Hello", client.sendRequest.Subject) + } + if len(client.sendRequest.To) != 1 || client.sendRequest.To[0].Email != "ada@example.com" { + t.Fatalf("To = %#v, want ada@example.com", client.sendRequest.To) + } + + var msg domain.Message + unmarshalResult(t, resp, &msg) + if msg.ID != "msg-1" || msg.Subject != "Hello" { + t.Fatalf("message = %#v, want msg-1 Hello", msg) + } + }, + }, + { + name: "email.update forwards request and returns message", + method: "email.update", + params: `{"message_id":"msg-1","unread":false,"starred":true,"folders":["sent"]}`, + client: &fakeEmailWriteClient{ + updateMessageResult: &domain.Message{ID: "msg-1", Starred: true}, + }, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.updateGrantID != "default-grant" || client.updateMessageID != "msg-1" { + t.Fatalf("update args = %q/%q, want default-grant/msg-1", client.updateGrantID, client.updateMessageID) + } + if client.updateRequest == nil || client.updateRequest.Unread == nil || *client.updateRequest.Unread != unread { + t.Fatalf("Unread = %#v, want false", client.updateRequest) + } + if client.updateRequest.Starred == nil || *client.updateRequest.Starred != starred { + t.Fatalf("Starred = %#v, want true", client.updateRequest.Starred) + } + if len(client.updateRequest.Folders) != 1 || client.updateRequest.Folders[0] != "sent" { + t.Fatalf("Folders = %#v, want sent", client.updateRequest.Folders) + } + }, + }, + { + name: "email.delete deletes and returns deleted", + method: "email.delete", + params: `{"grant_id":"grant-1","message_id":"msg-1"}`, + client: &fakeEmailWriteClient{}, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.deleteGrantID != "grant-1" || client.deleteMessageID != "msg-1" { + t.Fatalf("delete args = %q/%q, want grant-1/msg-1", client.deleteGrantID, client.deleteMessageID) + } + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "draft.create forwards request and returns draft", + method: "draft.create", + params: `{"grant_id":"grant-1","subject":"Draft","body":"Body","to":[{"email":"grace@example.com"}]}`, + client: &fakeEmailWriteClient{ + createDraftResult: &domain.Draft{ID: "draft-1", Subject: "Draft"}, + }, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.createDraftGrantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", client.createDraftGrantID) + } + if client.createDraftRequest == nil || client.createDraftRequest.Subject != "Draft" { + t.Fatalf("createDraftRequest = %#v, want subject Draft", client.createDraftRequest) + } + + var draft domain.Draft + unmarshalResult(t, resp, &draft) + if draft.ID != "draft-1" || draft.Subject != "Draft" { + t.Fatalf("draft = %#v, want draft-1 Draft", draft) + } + }, + }, + { + name: "draft.update forwards request and returns draft", + method: "draft.update", + params: `{"draft_id":"draft-1","subject":"Updated","body":"Body","signature_id":"sig-1"}`, + client: &fakeEmailWriteClient{ + updateDraftResult: &domain.Draft{ID: "draft-1", Subject: "Updated"}, + }, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.updateDraftGrantID != "default-grant" || client.updateDraftID != "draft-1" { + t.Fatalf("update draft args = %q/%q, want default-grant/draft-1", client.updateDraftGrantID, client.updateDraftID) + } + if client.updateDraftRequest == nil || client.updateDraftRequest.SignatureID != "sig-1" { + t.Fatalf("SignatureID = %#v, want sig-1", client.updateDraftRequest) + } + }, + }, + { + name: "draft.delete deletes and returns deleted", + method: "draft.delete", + params: `{"grant_id":"grant-1","draft_id":"draft-1"}`, + client: &fakeEmailWriteClient{}, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.deleteDraftGrantID != "grant-1" || client.deleteDraftID != "draft-1" { + t.Fatalf("delete draft args = %q/%q, want grant-1/draft-1", client.deleteDraftGrantID, client.deleteDraftID) + } + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "draft.send forwards request and returns message", + method: "draft.send", + params: `{"draft_id":"draft-1","signature_id":"sig-1"}`, + client: &fakeEmailWriteClient{ + sendDraftResult: &domain.Message{ID: "msg-1"}, + }, + assert: func(t *testing.T, client *fakeEmailWriteClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.sendDraftGrantID != "default-grant" || client.sendDraftID != "draft-1" { + t.Fatalf("send draft args = %q/%q, want default-grant/draft-1", client.sendDraftGrantID, client.sendDraftID) + } + if client.sendDraftRequest == nil || client.sendDraftRequest.SignatureID != "sig-1" { + t.Fatalf("SignatureID = %#v, want sig-1", client.sendDraftRequest) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterEmailWriteHandlers(d, tt.client, "default-grant") + + resp := dispatchEmailRequest(t, d, tt.method, tt.params) + tt.assert(t, tt.client, resp) + }) + } +} + +func TestRegisterEmailWriteHandlers_InvalidParams(t *testing.T) { + tests := []struct { + name string + method string + params string + }{ + {name: "email.send missing grant", method: "email.send", params: `{}`}, + {name: "email.update missing message_id", method: "email.update", params: `{"grant_id":"grant-1"}`}, + {name: "email.delete missing message_id", method: "email.delete", params: `{"grant_id":"grant-1"}`}, + {name: "draft.update missing draft_id", method: "draft.update", params: `{"grant_id":"grant-1"}`}, + {name: "draft.delete missing draft_id", method: "draft.delete", params: `{"grant_id":"grant-1"}`}, + {name: "draft.send missing draft_id", method: "draft.send", params: `{"grant_id":"grant-1"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterEmailWriteHandlers(d, &fakeEmailWriteClient{}, "") + + resp := dispatchEmailRequest(t, d, tt.method, tt.params) + requireRPCErrorCode(t, resp, InvalidParams) + }) + } +} + +func TestRegisterEmailWriteHandlers_ClientError(t *testing.T) { + d := NewDispatcher() + RegisterEmailWriteHandlers(d, &fakeEmailWriteClient{ + err: errors.New("client unavailable"), + }, "grant-1") + + resp := dispatchEmailRequest(t, d, "email.delete", `{"message_id":"msg-1"}`) + requireRPCErrorCode(t, resp, InternalError) +} diff --git a/internal/adapters/rpcserver/handlers_local.go b/internal/adapters/rpcserver/handlers_local.go new file mode 100644 index 0000000..cc26ab5 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_local.go @@ -0,0 +1,122 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type agentAccountGetParams struct { + GrantID string `json:"grant_id"` +} + +type agentAccountListResult struct { + Accounts []domain.AgentAccount `json:"accounts"` +} + +type grantListResult struct { + Grants []domain.GrantInfo `json:"grants"` +} + +type configLoader interface { + Load() (*domain.Config, error) +} + +type configReadResult struct { + Region string `json:"region"` + DefaultGrant string `json:"default_grant"` + CallbackPort int `json:"callback_port"` + TUITheme string `json:"tui_theme"` + API *configReadAPI `json:"api,omitempty"` + WorkingHours *domain.WorkingHoursConfig `json:"working_hours"` + AIConfigured bool `json:"ai_configured"` + GPGConfigured bool `json:"gpg_configured"` + DashboardConfigured bool `json:"dashboard_configured"` +} + +type configReadAPI struct { + BaseURL string `json:"base_url"` + Timeout string `json:"timeout"` +} + +func RegisterAgentHandlers(d *Dispatcher, client ports.AgentClient) { + d.Register("agentAccount.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + accounts, err := client.ListAgentAccounts(ctx) + if err != nil { + return nil, fmt.Errorf("agentAccount.list: %w", err) + } + return agentAccountListResult{Accounts: accounts}, nil + }) + + d.Register("agentAccount.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p agentAccountGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.GrantID == "" { + return nil, NewRPCError(InvalidParams, "grant_id required", nil) + } + + account, err := client.GetAgentAccount(ctx, p.GrantID) + if err != nil { + return nil, fmt.Errorf("agentAccount.get: %w", err) + } + return account, nil + }) +} + +func RegisterGrantHandlers(d *Dispatcher, store ports.GrantStore) { + d.Register("grant.list", func(_ context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grants, err := store.ListGrants() + if err != nil { + return nil, fmt.Errorf("grant.list: %w", err) + } + return grantListResult{Grants: grants}, nil + }) +} + +func RegisterConfigHandlers(d *Dispatcher, loader configLoader) { + d.Register("config.read", func(_ context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + cfg, err := loader.Load() + if err != nil { + return nil, fmt.Errorf("config.read: %w", err) + } + + result := configReadResult{ + Region: cfg.Region, + DefaultGrant: cfg.DefaultGrant, + CallbackPort: cfg.CallbackPort, + TUITheme: cfg.TUITheme, + WorkingHours: cfg.WorkingHours, + AIConfigured: cfg.AI != nil, + GPGConfigured: cfg.GPG != nil, + DashboardConfigured: cfg.Dashboard != nil, + } + if cfg.API != nil { + result.API = &configReadAPI{ + BaseURL: cfg.API.BaseURL, + Timeout: cfg.API.Timeout, + } + } + + return result, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_local_test.go b/internal/adapters/rpcserver/handlers_local_test.go new file mode 100644 index 0000000..9819e57 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_local_test.go @@ -0,0 +1,247 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeLocalAgentClient struct { + ports.AgentClient + + listAgentAccounts func(context.Context) ([]domain.AgentAccount, error) + getAgentAccount func(context.Context, string) (*domain.AgentAccount, error) +} + +func (f *fakeLocalAgentClient) ListAgentAccounts(ctx context.Context) ([]domain.AgentAccount, error) { + if f.listAgentAccounts == nil { + return nil, errors.New("unexpected ListAgentAccounts") + } + return f.listAgentAccounts(ctx) +} + +func (f *fakeLocalAgentClient) GetAgentAccount(ctx context.Context, grantID string) (*domain.AgentAccount, error) { + if f.getAgentAccount == nil { + return nil, errors.New("unexpected GetAgentAccount") + } + return f.getAgentAccount(ctx, grantID) +} + +type fakeLocalGrantStore struct { + ports.GrantStore + + listGrants func() ([]domain.GrantInfo, error) +} + +func (f *fakeLocalGrantStore) ListGrants() ([]domain.GrantInfo, error) { + if f.listGrants == nil { + return nil, errors.New("unexpected ListGrants") + } + return f.listGrants() +} + +type fakeLocalConfigLoader struct { + load func() (*domain.Config, error) +} + +func (f *fakeLocalConfigLoader) Load() (*domain.Config, error) { + if f.load == nil { + return nil, errors.New("unexpected Load") + } + return f.load() +} + +func TestRegisterAgentHandlers_Local(t *testing.T) { + tests := []struct { + name string + method string + params string + client *fakeLocalAgentClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "agentAccount.list returns accounts", + method: "agentAccount.list", + params: `{}`, + client: &fakeLocalAgentClient{ + listAgentAccounts: func(ctx context.Context) ([]domain.AgentAccount, error) { + return []domain.AgentAccount{ + {ID: "grant-1", Provider: domain.ProviderNylas, Email: "agent@example.com", Name: "Agent One"}, + }, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result struct { + Accounts []domain.AgentAccount `json:"accounts"` + } + unmarshalResult(t, resp, &result) + if len(result.Accounts) != 1 || result.Accounts[0].ID != "grant-1" || result.Accounts[0].Email != "agent@example.com" { + t.Fatalf("accounts = %#v, want grant-1 agent@example.com", result.Accounts) + } + }, + }, + { + name: "agentAccount.get with grant_id returns account", + method: "agentAccount.get", + params: `{"grant_id":"grant-1"}`, + client: &fakeLocalAgentClient{ + getAgentAccount: func(ctx context.Context, grantID string) (*domain.AgentAccount, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", grantID) + } + return &domain.AgentAccount{ID: "grant-1", Provider: domain.ProviderNylas, Email: "agent@example.com"}, nil + }, + }, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var account domain.AgentAccount + unmarshalResult(t, resp, &account) + if account.ID != "grant-1" || account.Provider != domain.ProviderNylas { + t.Fatalf("account = %#v, want grant-1 nylas", account) + } + }, + }, + { + name: "agentAccount.get missing grant_id returns invalid params", + method: "agentAccount.get", + params: `{}`, + client: &fakeLocalAgentClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterAgentHandlers(d, tt.client) + + resp := dispatchLocalRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func TestRegisterGrantHandlers_Local(t *testing.T) { + d := NewDispatcher() + RegisterGrantHandlers(d, &fakeLocalGrantStore{ + listGrants: func() ([]domain.GrantInfo, error) { + return []domain.GrantInfo{ + {ID: "grant-1", Email: "user@example.com", Provider: domain.ProviderGoogle}, + }, nil + }, + }) + + resp := dispatchLocalRequest(t, d, "grant.list", `{}`) + requireNoRPCError(t, resp) + + var result struct { + Grants []domain.GrantInfo `json:"grants"` + } + unmarshalResult(t, resp, &result) + if len(result.Grants) != 1 || result.Grants[0].ID != "grant-1" || result.Grants[0].Provider != domain.ProviderGoogle { + t.Fatalf("grants = %#v, want grant-1 google", result.Grants) + } +} + +func TestRegisterConfigHandlers_Local(t *testing.T) { + d := NewDispatcher() + RegisterConfigHandlers(d, &fakeLocalConfigLoader{ + load: func() (*domain.Config, error) { + return &domain.Config{ + Region: "eu", + DefaultGrant: "grant-1", + CallbackPort: 9008, + Grants: []domain.GrantInfo{{ID: "hidden-grant", Email: "hidden@example.com", Provider: domain.ProviderGoogle}}, + API: &domain.APIConfig{BaseURL: "https://api.example.test", Timeout: "30s"}, + TUITheme: "catppuccin", + WorkingHours: &domain.WorkingHoursConfig{ + Default: &domain.DaySchedule{Enabled: true, Start: "09:00", End: "17:00"}, + }, + AI: &domain.AIConfig{DefaultProvider: "openai"}, + GPG: &domain.GPGConfig{DefaultKey: "key-id"}, + Dashboard: &domain.DashboardConfig{AccountBaseURL: "https://dashboard.example.test"}, + }, nil + }, + }) + + resp := dispatchLocalRequest(t, d, "config.read", `{}`) + requireNoRPCError(t, resp) + + var result map[string]json.RawMessage + unmarshalResult(t, resp, &result) + requireJSONBool(t, result, "ai_configured", true) + requireJSONBool(t, result, "gpg_configured", true) + requireJSONBool(t, result, "dashboard_configured", true) + + for _, key := range []string{"ai", "gpg", "dashboard", "grants"} { + if _, ok := result[key]; ok { + t.Fatalf("result contains %q key: %s", key, resp.Result) + } + } + + var whitelisted struct { + Region string `json:"region"` + DefaultGrant string `json:"default_grant"` + CallbackPort int `json:"callback_port"` + TUITheme string `json:"tui_theme"` + API struct { + BaseURL string `json:"base_url"` + Timeout string `json:"timeout"` + } `json:"api"` + WorkingHours *domain.WorkingHoursConfig `json:"working_hours"` + } + unmarshalResult(t, resp, &whitelisted) + if whitelisted.Region != "eu" || whitelisted.DefaultGrant != "grant-1" || whitelisted.CallbackPort != 9008 || whitelisted.TUITheme != "catppuccin" { + t.Fatalf("config result = %+v, want whitelisted scalar fields", whitelisted) + } + if whitelisted.API.BaseURL != "https://api.example.test" || whitelisted.API.Timeout != "30s" { + t.Fatalf("api = %+v, want base_url and timeout", whitelisted.API) + } + if whitelisted.WorkingHours == nil { + t.Fatal("working_hours = nil, want configured working hours") + } +} + +func dispatchLocalRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} + +func requireJSONBool(t *testing.T, fields map[string]json.RawMessage, key string, want bool) { + t.Helper() + + raw, ok := fields[key] + if !ok { + t.Fatalf("missing %q key", key) + } + var got bool + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal %q: %v", key, err) + } + if got != want { + t.Fatalf("%s = %t, want %t", key, got, want) + } +} diff --git a/internal/adapters/rpcserver/handlers_notetaker.go b/internal/adapters/rpcserver/handlers_notetaker.go new file mode 100644 index 0000000..f60750d --- /dev/null +++ b/internal/adapters/rpcserver/handlers_notetaker.go @@ -0,0 +1,195 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type notetakerListParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.NotetakerQueryParams +} + +type notetakerListResult struct { + Notetakers []domain.Notetaker `json:"notetakers"` +} + +type notetakerGetParams struct { + GrantID string `json:"grant_id,omitempty"` + NotetakerID string `json:"notetaker_id"` +} + +type notetakerCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + domain.CreateNotetakerRequest +} + +type notetakerUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + NotetakerID string `json:"notetaker_id"` + domain.UpdateNotetakerRequest +} + +type notetakerDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + NotetakerID string `json:"notetaker_id"` +} + +type notetakerLeaveParams struct { + GrantID string `json:"grant_id,omitempty"` + NotetakerID string `json:"notetaker_id"` +} + +type notetakerMediaParams struct { + GrantID string `json:"grant_id,omitempty"` + NotetakerID string `json:"notetaker_id"` +} + +type leftResult struct { + Left bool `json:"left"` +} + +func RegisterNotetakerHandlers(d *Dispatcher, client ports.NotetakerClient, defaultGrant string) { + d.Register("notetaker.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + notetakers, err := client.ListNotetakers(ctx, grantID, &p.NotetakerQueryParams) + if err != nil { + return nil, fmt.Errorf("notetaker.list: %w", err) + } + return notetakerListResult{Notetakers: notetakers}, nil + }) + + d.Register("notetaker.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.NotetakerID == "" { + return nil, NewRPCError(InvalidParams, "notetaker_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + notetaker, err := client.GetNotetaker(ctx, grantID, p.NotetakerID) + if err != nil { + return nil, fmt.Errorf("notetaker.get: %w", err) + } + return notetaker, nil + }) + + d.Register("notetaker.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + notetaker, err := client.CreateNotetaker(ctx, grantID, &p.CreateNotetakerRequest) + if err != nil { + return nil, fmt.Errorf("notetaker.create: %w", err) + } + return notetaker, nil + }) + + d.Register("notetaker.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.NotetakerID == "" { + return nil, NewRPCError(InvalidParams, "notetaker_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + notetaker, err := client.UpdateNotetaker(ctx, grantID, p.NotetakerID, &p.UpdateNotetakerRequest) + if err != nil { + return nil, fmt.Errorf("notetaker.update: %w", err) + } + return notetaker, nil + }) + + d.Register("notetaker.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.NotetakerID == "" { + return nil, NewRPCError(InvalidParams, "notetaker_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteNotetaker(ctx, grantID, p.NotetakerID); err != nil { + return nil, fmt.Errorf("notetaker.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("notetaker.leave", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerLeaveParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.NotetakerID == "" { + return nil, NewRPCError(InvalidParams, "notetaker_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.LeaveNotetaker(ctx, grantID, p.NotetakerID); err != nil { + return nil, fmt.Errorf("notetaker.leave: %w", err) + } + return leftResult{Left: true}, nil + }) + + d.Register("notetaker.media", func(ctx context.Context, params json.RawMessage) (any, error) { + var p notetakerMediaParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.NotetakerID == "" { + return nil, NewRPCError(InvalidParams, "notetaker_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + media, err := client.GetNotetakerMedia(ctx, grantID, p.NotetakerID) + if err != nil { + return nil, fmt.Errorf("notetaker.media: %w", err) + } + return media, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_notetaker_test.go b/internal/adapters/rpcserver/handlers_notetaker_test.go new file mode 100644 index 0000000..a288f38 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_notetaker_test.go @@ -0,0 +1,463 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeNotetakerClient struct { + ports.NotetakerClient + + err error + + listResult []domain.Notetaker + getResult *domain.Notetaker + createResult *domain.Notetaker + updateResult *domain.Notetaker + mediaResult *domain.MediaData + + method string + grantID string + notetakerID string + listParams *domain.NotetakerQueryParams + createReq *domain.CreateNotetakerRequest + updateReq *domain.UpdateNotetakerRequest + deleteCalled bool + leaveCalled bool +} + +func (f *fakeNotetakerClient) ListNotetakers(ctx context.Context, grantID string, params *domain.NotetakerQueryParams) ([]domain.Notetaker, error) { + f.method = "list" + f.grantID = grantID + f.listParams = params + return f.listResult, f.err +} + +func (f *fakeNotetakerClient) GetNotetaker(ctx context.Context, grantID, notetakerID string) (*domain.Notetaker, error) { + f.method = "get" + f.grantID = grantID + f.notetakerID = notetakerID + return f.getResult, f.err +} + +func (f *fakeNotetakerClient) CreateNotetaker(ctx context.Context, grantID string, req *domain.CreateNotetakerRequest) (*domain.Notetaker, error) { + f.method = "create" + f.grantID = grantID + f.createReq = req + return f.createResult, f.err +} + +func (f *fakeNotetakerClient) UpdateNotetaker(ctx context.Context, grantID, notetakerID string, req *domain.UpdateNotetakerRequest) (*domain.Notetaker, error) { + f.method = "update" + f.grantID = grantID + f.notetakerID = notetakerID + f.updateReq = req + return f.updateResult, f.err +} + +func (f *fakeNotetakerClient) DeleteNotetaker(ctx context.Context, grantID, notetakerID string) error { + f.method = "delete" + f.grantID = grantID + f.notetakerID = notetakerID + f.deleteCalled = true + return f.err +} + +func (f *fakeNotetakerClient) LeaveNotetaker(ctx context.Context, grantID, notetakerID string) error { + f.method = "leave" + f.grantID = grantID + f.notetakerID = notetakerID + f.leaveCalled = true + return f.err +} + +func (f *fakeNotetakerClient) GetNotetakerMedia(ctx context.Context, grantID, notetakerID string) (*domain.MediaData, error) { + f.method = "media" + f.grantID = grantID + f.notetakerID = notetakerID + return f.mediaResult, f.err +} + +func TestRegisterNotetakerHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + enabled := true + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeNotetakerClient + assert func(*testing.T, *fakeNotetakerClient, rpcTestResponse) + }{ + { + name: "notetaker.list returns notetakers", + method: "notetaker.list", + params: `{"limit":2,"page_token":"cursor-1","state":"complete"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{ + listResult: []domain.Notetaker{{ID: "nt-1", State: domain.NotetakerStateComplete}}, + }, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", client.grantID) + } + if client.listParams == nil || client.listParams.Limit != 2 || client.listParams.PageToken != "cursor-1" || client.listParams.State != "complete" { + t.Fatalf("listParams = %#v, want forwarded query params", client.listParams) + } + + var result notetakerListResult + unmarshalResult(t, resp, &result) + if len(result.Notetakers) != 1 || result.Notetakers[0].ID != "nt-1" { + t.Fatalf("notetakers = %#v, want nt-1", result.Notetakers) + } + }, + }, + { + name: "notetaker.list missing grant returns invalid params", + method: "notetaker.list", + params: `{}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.method != "" { + t.Fatalf("method = %q, want no client call", client.method) + } + }, + }, + { + name: "notetaker.list client error maps to internal error", + method: "notetaker.list", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.get returns notetaker", + method: "notetaker.get", + params: `{"grant_id":"grant-1","notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{getResult: &domain.Notetaker{ID: "nt-1", State: domain.NotetakerStateAttending}}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "grant-1" || client.notetakerID != "nt-1" { + t.Fatalf("args = %q/%q, want grant-1/nt-1", client.grantID, client.notetakerID) + } + + var nt domain.Notetaker + unmarshalResult(t, resp, &nt) + if nt.ID != "nt-1" || nt.State != domain.NotetakerStateAttending { + t.Fatalf("notetaker = %#v, want nt-1 attending", nt) + } + }, + }, + { + name: "notetaker.get missing grant returns invalid params", + method: "notetaker.get", + params: `{"notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "notetaker.get missing notetaker_id returns invalid params", + method: "notetaker.get", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.method != "" { + t.Fatalf("method = %q, want no client call", client.method) + } + }, + }, + { + name: "notetaker.get client error maps to internal error", + method: "notetaker.get", + params: `{"notetaker_id":"nt-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.create forwards request", + method: "notetaker.create", + params: `{"meeting_link":"https://meet.example/abc","join_time":1710000000,"bot_config":{"name":"Nyla"}}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{createResult: &domain.Notetaker{ID: "nt-1", MeetingLink: "https://meet.example/abc"}}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", client.grantID) + } + if client.createReq == nil || client.createReq.MeetingLink != "https://meet.example/abc" || client.createReq.JoinTime != 1710000000 { + t.Fatalf("createReq = %#v, want meeting link and join time", client.createReq) + } + if client.createReq.BotConfig == nil || client.createReq.BotConfig.Name != "Nyla" { + t.Fatalf("BotConfig = %#v, want Nyla", client.createReq.BotConfig) + } + + var nt domain.Notetaker + unmarshalResult(t, resp, &nt) + if nt.ID != "nt-1" { + t.Fatalf("notetaker = %#v, want nt-1", nt) + } + }, + }, + { + name: "notetaker.create missing grant returns invalid params", + method: "notetaker.create", + params: `{"meeting_link":"https://meet.example/abc"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.createReq != nil { + t.Fatalf("createReq = %#v, want nil", client.createReq) + } + }, + }, + { + name: "notetaker.create client error maps to internal error", + method: "notetaker.create", + params: `{"meeting_link":"https://meet.example/abc"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.update forwards request", + method: "notetaker.update", + params: `{"notetaker_id":"nt-1","join_time":1710000100,"name":"Updated","meeting_settings":{"audio_recording":true}}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{updateResult: &domain.Notetaker{ID: "nt-1", State: domain.NotetakerStateScheduled}}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "default-grant" || client.notetakerID != "nt-1" { + t.Fatalf("args = %q/%q, want default-grant/nt-1", client.grantID, client.notetakerID) + } + if client.updateReq == nil || client.updateReq.JoinTime != 1710000100 || client.updateReq.Name != "Updated" { + t.Fatalf("updateReq = %#v, want join time and name", client.updateReq) + } + if client.updateReq.MeetingSettings == nil || client.updateReq.MeetingSettings.AudioRecording == nil || *client.updateReq.MeetingSettings.AudioRecording != enabled { + t.Fatalf("MeetingSettings = %#v, want audio recording true", client.updateReq.MeetingSettings) + } + }, + }, + { + name: "notetaker.update missing grant returns invalid params", + method: "notetaker.update", + params: `{"notetaker_id":"nt-1","name":"Updated"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.updateReq != nil { + t.Fatalf("updateReq = %#v, want nil", client.updateReq) + } + }, + }, + { + name: "notetaker.update missing notetaker_id returns invalid params", + method: "notetaker.update", + params: `{"grant_id":"grant-1","name":"Updated"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.updateReq != nil { + t.Fatalf("updateReq = %#v, want nil", client.updateReq) + } + }, + }, + { + name: "notetaker.update client error maps to internal error", + method: "notetaker.update", + params: `{"notetaker_id":"nt-1","name":"Updated"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.delete returns deleted", + method: "notetaker.delete", + params: `{"grant_id":"grant-1","notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if !client.deleteCalled || client.grantID != "grant-1" || client.notetakerID != "nt-1" { + t.Fatalf("delete = %v %q/%q, want true grant-1/nt-1", client.deleteCalled, client.grantID, client.notetakerID) + } + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "notetaker.delete missing grant returns invalid params", + method: "notetaker.delete", + params: `{"notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.deleteCalled { + t.Fatal("deleteCalled = true, want false") + } + }, + }, + { + name: "notetaker.delete missing notetaker_id returns invalid params", + method: "notetaker.delete", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.deleteCalled { + t.Fatal("deleteCalled = true, want false") + } + }, + }, + { + name: "notetaker.delete client error maps to internal error", + method: "notetaker.delete", + params: `{"notetaker_id":"nt-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.leave returns left", + method: "notetaker.leave", + params: `{"grant_id":"grant-1","notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if !client.leaveCalled || client.grantID != "grant-1" || client.notetakerID != "nt-1" { + t.Fatalf("leave = %v %q/%q, want true grant-1/nt-1", client.leaveCalled, client.grantID, client.notetakerID) + } + var result leftResult + unmarshalResult(t, resp, &result) + if !result.Left { + t.Fatal("left = false, want true") + } + }, + }, + { + name: "notetaker.leave missing grant returns invalid params", + method: "notetaker.leave", + params: `{"notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.leaveCalled { + t.Fatal("leaveCalled = true, want false") + } + }, + }, + { + name: "notetaker.leave missing notetaker_id returns invalid params", + method: "notetaker.leave", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.leaveCalled { + t.Fatal("leaveCalled = true, want false") + } + }, + }, + { + name: "notetaker.leave client error maps to internal error", + method: "notetaker.leave", + params: `{"notetaker_id":"nt-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "notetaker.media returns media", + method: "notetaker.media", + params: `{"grant_id":"grant-1","notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{ + mediaResult: &domain.MediaData{Recording: &domain.MediaFile{URL: "https://files.example/rec.mp4"}}, + }, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "grant-1" || client.notetakerID != "nt-1" { + t.Fatalf("args = %q/%q, want grant-1/nt-1", client.grantID, client.notetakerID) + } + var media domain.MediaData + unmarshalResult(t, resp, &media) + if media.Recording == nil || media.Recording.URL != "https://files.example/rec.mp4" { + t.Fatalf("media = %#v, want recording URL", media) + } + }, + }, + { + name: "notetaker.media missing grant returns invalid params", + method: "notetaker.media", + params: `{"notetaker_id":"nt-1"}`, + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.method != "" { + t.Fatalf("method = %q, want no client call", client.method) + } + }, + }, + { + name: "notetaker.media missing notetaker_id returns invalid params", + method: "notetaker.media", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.method != "" { + t.Fatalf("method = %q, want no client call", client.method) + } + }, + }, + { + name: "notetaker.media client error maps to internal error", + method: "notetaker.media", + params: `{"notetaker_id":"nt-1"}`, + defaultGrant: "default-grant", + client: &fakeNotetakerClient{err: clientErr}, + assert: func(t *testing.T, client *fakeNotetakerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterNotetakerHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchEmailRequest(t, d, tt.method, tt.params) + tt.assert(t, tt.client, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_otp.go b/internal/adapters/rpcserver/handlers_otp.go new file mode 100644 index 0000000..b890b1a --- /dev/null +++ b/internal/adapters/rpcserver/handlers_otp.go @@ -0,0 +1,41 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" +) + +type otpGetParams struct { + Email string `json:"email"` +} + +type otpService interface { + GetOTP(ctx context.Context, email string) (*domain.OTPResult, error) + GetOTPDefault(ctx context.Context) (*domain.OTPResult, error) +} + +func RegisterOTPHandlers(d *Dispatcher, svc otpService) { + d.Register("otp.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p otpGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + var ( + result *domain.OTPResult + err error + ) + if p.Email != "" { + result, err = svc.GetOTP(ctx, p.Email) + } else { + result, err = svc.GetOTPDefault(ctx) + } + if err != nil { + return nil, fmt.Errorf("otp.get: %w", err) + } + return result, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_otp_test.go b/internal/adapters/rpcserver/handlers_otp_test.go new file mode 100644 index 0000000..e4230ae --- /dev/null +++ b/internal/adapters/rpcserver/handlers_otp_test.go @@ -0,0 +1,102 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" +) + +type fakeOTPService struct { + calledGetOTP bool + calledGetOTPDefault bool + email string + result *domain.OTPResult + err error +} + +func (f *fakeOTPService) GetOTP(_ context.Context, email string) (*domain.OTPResult, error) { + f.calledGetOTP = true + f.email = email + return f.result, f.err +} + +func (f *fakeOTPService) GetOTPDefault(_ context.Context) (*domain.OTPResult, error) { + f.calledGetOTPDefault = true + return f.result, f.err +} + +func TestRegisterOTPHandlers(t *testing.T) { + tests := []struct { + name string + params string + svc *fakeOTPService + assert func(*testing.T, *fakeOTPService, rpcTestResponse) + }{ + { + name: "otp.get with email routes to GetOTP", + params: `{"email":"user@example.com"}`, + svc: &fakeOTPService{ + result: &domain.OTPResult{Code: "123456", From: "sender@example.com", Subject: "Your code", MessageID: "msg-1"}, + }, + assert: func(t *testing.T, svc *fakeOTPService, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if !svc.calledGetOTP || svc.calledGetOTPDefault { + t.Fatalf("calledGetOTP = %t, calledGetOTPDefault = %t; want GetOTP only", svc.calledGetOTP, svc.calledGetOTPDefault) + } + if svc.email != "user@example.com" { + t.Fatalf("email = %q, want user@example.com", svc.email) + } + + var result domain.OTPResult + unmarshalResult(t, resp, &result) + if result.Code != "123456" || result.From != "sender@example.com" || result.Subject != "Your code" || result.MessageID != "msg-1" { + t.Fatalf("result = %#v, want returned OTP result", result) + } + }, + }, + { + name: "otp.get without email routes to GetOTPDefault", + params: `{}`, + svc: &fakeOTPService{ + result: &domain.OTPResult{Code: "654321", From: "default@example.com", Subject: "Default code", MessageID: "msg-2"}, + }, + assert: func(t *testing.T, svc *fakeOTPService, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if svc.calledGetOTP || !svc.calledGetOTPDefault { + t.Fatalf("calledGetOTP = %t, calledGetOTPDefault = %t; want GetOTPDefault only", svc.calledGetOTP, svc.calledGetOTPDefault) + } + + var result domain.OTPResult + unmarshalResult(t, resp, &result) + if result.Code != "654321" || result.From != "default@example.com" || result.Subject != "Default code" || result.MessageID != "msg-2" { + t.Fatalf("result = %#v, want returned OTP result", result) + } + }, + }, + { + name: "otp.get service error returns internal error", + params: `{"email":"user@example.com"}`, + svc: &fakeOTPService{ + err: errors.New("otp unavailable"), + }, + assert: func(t *testing.T, svc *fakeOTPService, resp rpcTestResponse) { + if !svc.calledGetOTP || svc.email != "user@example.com" { + t.Fatalf("calledGetOTP = %t, email = %q; want GetOTP with email", svc.calledGetOTP, svc.email) + } + requireRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterOTPHandlers(d, tt.svc) + + resp := dispatchLocalRequest(t, d, "otp.get", tt.params) + tt.assert(t, tt.svc, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_scheduler.go b/internal/adapters/rpcserver/handlers_scheduler.go new file mode 100644 index 0000000..fe8d4b1 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_scheduler.go @@ -0,0 +1,372 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type schedulerConfigListResult struct { + Configurations []domain.SchedulerConfiguration `json:"configurations"` +} + +type schedulerConfigGetParams struct { + ConfigID string `json:"config_id"` +} + +type schedulerConfigCreateParams struct { + domain.CreateSchedulerConfigurationRequest +} + +type schedulerConfigUpdateParams struct { + ConfigID string `json:"config_id"` + domain.UpdateSchedulerConfigurationRequest +} + +type schedulerSessionCreateParams struct { + domain.CreateSchedulerSessionRequest +} + +type schedulerSessionGetParams struct { + SessionID string `json:"session_id"` +} + +type schedulerBookingGetParams struct { + BookingID string `json:"booking_id"` +} + +type schedulerBookingConfirmParams struct { + BookingID string `json:"booking_id"` + domain.ConfirmBookingRequest +} + +type schedulerBookingRescheduleParams struct { + BookingID string `json:"booking_id"` + domain.RescheduleBookingRequest +} + +type schedulerBookingCancelParams struct { + BookingID string `json:"booking_id"` + Reason string `json:"reason,omitempty"` +} + +type schedulerBookingCancelResult struct { + Cancelled bool `json:"cancelled"` +} + +type schedulerGroupEventListParams struct { + GrantID string `json:"grant_id,omitempty"` + ConfigID string `json:"config_id"` + CalendarID string `json:"calendar_id"` + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` +} + +type schedulerGroupEventCreateParams struct { + GrantID string `json:"grant_id,omitempty"` + ConfigID string `json:"config_id"` + domain.CreateGroupEventRequest +} + +type schedulerGroupEventUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + ConfigID string `json:"config_id"` + EventID string `json:"event_id"` + domain.UpdateGroupEventRequest +} + +type schedulerGroupEventDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + ConfigID string `json:"config_id"` + EventID string `json:"event_id"` +} + +type schedulerGroupEventImportParams struct { + ConfigID string `json:"config_id"` + Items []domain.ImportGroupEventItem `json:"items"` +} + +type schedulerGroupEventResult struct { + Events []domain.GroupEvent `json:"events"` +} + +func RegisterSchedulerHandlers(d *Dispatcher, client ports.SchedulerClient, defaultGrant string) { + d.Register("scheduler.config.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct{} + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + configurations, err := client.ListSchedulerConfigurations(ctx) + if err != nil { + return nil, fmt.Errorf("scheduler.config.list: %w", err) + } + return schedulerConfigListResult{Configurations: configurations}, nil + }) + + d.Register("scheduler.config.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerConfigGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + + config, err := client.GetSchedulerConfiguration(ctx, p.ConfigID) + if err != nil { + return nil, fmt.Errorf("scheduler.config.get: %w", err) + } + return config, nil + }) + + d.Register("scheduler.config.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerConfigCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + config, err := client.CreateSchedulerConfiguration(ctx, &p.CreateSchedulerConfigurationRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.config.create: %w", err) + } + return config, nil + }) + + d.Register("scheduler.config.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerConfigUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + + config, err := client.UpdateSchedulerConfiguration(ctx, p.ConfigID, &p.UpdateSchedulerConfigurationRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.config.update: %w", err) + } + return config, nil + }) + + d.Register("scheduler.config.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerConfigGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + + if err := client.DeleteSchedulerConfiguration(ctx, p.ConfigID); err != nil { + return nil, fmt.Errorf("scheduler.config.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("scheduler.session.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerSessionCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + session, err := client.CreateSchedulerSession(ctx, &p.CreateSchedulerSessionRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.session.create: %w", err) + } + return session, nil + }) + + d.Register("scheduler.session.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerSessionGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.SessionID == "" { + return nil, NewRPCError(InvalidParams, "session_id required", nil) + } + + session, err := client.GetSchedulerSession(ctx, p.SessionID) + if err != nil { + return nil, fmt.Errorf("scheduler.session.get: %w", err) + } + return session, nil + }) + + d.Register("scheduler.booking.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerBookingGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.BookingID == "" { + return nil, NewRPCError(InvalidParams, "booking_id required", nil) + } + + booking, err := client.GetBooking(ctx, p.BookingID) + if err != nil { + return nil, fmt.Errorf("scheduler.booking.get: %w", err) + } + return booking, nil + }) + + d.Register("scheduler.booking.confirm", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerBookingConfirmParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.BookingID == "" { + return nil, NewRPCError(InvalidParams, "booking_id required", nil) + } + + booking, err := client.ConfirmBooking(ctx, p.BookingID, &p.ConfirmBookingRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.booking.confirm: %w", err) + } + return booking, nil + }) + + d.Register("scheduler.booking.reschedule", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerBookingRescheduleParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.BookingID == "" { + return nil, NewRPCError(InvalidParams, "booking_id required", nil) + } + + booking, err := client.RescheduleBooking(ctx, p.BookingID, &p.RescheduleBookingRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.booking.reschedule: %w", err) + } + return booking, nil + }) + + d.Register("scheduler.booking.cancel", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerBookingCancelParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.BookingID == "" { + return nil, NewRPCError(InvalidParams, "booking_id required", nil) + } + + if err := client.CancelBooking(ctx, p.BookingID, p.Reason); err != nil { + return nil, fmt.Errorf("scheduler.booking.cancel: %w", err) + } + return schedulerBookingCancelResult{Cancelled: true}, nil + }) + + d.Register("scheduler.groupEvent.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerGroupEventListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + if p.CalendarID == "" { + return nil, NewRPCError(InvalidParams, "calendar_id required", nil) + } + if p.StartTime <= 0 || p.EndTime <= 0 { + return nil, NewRPCError(InvalidParams, "start_time and end_time required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + events, err := client.ListGroupEvents(ctx, grantID, p.ConfigID, p.CalendarID, p.StartTime, p.EndTime) + if err != nil { + return nil, fmt.Errorf("scheduler.groupEvent.list: %w", err) + } + return schedulerGroupEventResult{Events: events}, nil + }) + + d.Register("scheduler.groupEvent.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerGroupEventCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + events, err := client.CreateGroupEvent(ctx, grantID, p.ConfigID, &p.CreateGroupEventRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.groupEvent.create: %w", err) + } + return schedulerGroupEventResult{Events: events}, nil + }) + + d.Register("scheduler.groupEvent.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerGroupEventUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + events, err := client.UpdateGroupEvent(ctx, grantID, p.ConfigID, p.EventID, &p.UpdateGroupEventRequest) + if err != nil { + return nil, fmt.Errorf("scheduler.groupEvent.update: %w", err) + } + return schedulerGroupEventResult{Events: events}, nil + }) + + d.Register("scheduler.groupEvent.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerGroupEventDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + if p.EventID == "" { + return nil, NewRPCError(InvalidParams, "event_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteGroupEvent(ctx, grantID, p.ConfigID, p.EventID); err != nil { + return nil, fmt.Errorf("scheduler.groupEvent.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("scheduler.groupEvent.import", func(ctx context.Context, params json.RawMessage) (any, error) { + var p schedulerGroupEventImportParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ConfigID == "" { + return nil, NewRPCError(InvalidParams, "config_id required", nil) + } + + events, err := client.ImportGroupEvents(ctx, p.ConfigID, p.Items) + if err != nil { + return nil, fmt.Errorf("scheduler.groupEvent.import: %w", err) + } + return schedulerGroupEventResult{Events: events}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_scheduler_test.go b/internal/adapters/rpcserver/handlers_scheduler_test.go new file mode 100644 index 0000000..dbf34d0 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_scheduler_test.go @@ -0,0 +1,277 @@ +package rpcserver + +import ( + "context" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeSchedulerClient struct { + ports.SchedulerClient + + listSchedulerConfigurations func(context.Context) ([]domain.SchedulerConfiguration, error) + getSchedulerConfiguration func(context.Context, string) (*domain.SchedulerConfiguration, error) + getSchedulerSession func(context.Context, string) (*domain.SchedulerSession, error) + getBooking func(context.Context, string) (*domain.Booking, error) + listGroupEvents func(context.Context, string, string, string, int64, int64) ([]domain.GroupEvent, error) + + configID string + sessionID string + bookingID string + grantID string + calendarID string + startTime int64 + endTime int64 +} + +func (f *fakeSchedulerClient) ListSchedulerConfigurations(ctx context.Context) ([]domain.SchedulerConfiguration, error) { + if f.listSchedulerConfigurations == nil { + return nil, errors.New("unexpected ListSchedulerConfigurations") + } + return f.listSchedulerConfigurations(ctx) +} + +func (f *fakeSchedulerClient) GetSchedulerConfiguration(ctx context.Context, configID string) (*domain.SchedulerConfiguration, error) { + f.configID = configID + if f.getSchedulerConfiguration == nil { + return nil, errors.New("unexpected GetSchedulerConfiguration") + } + return f.getSchedulerConfiguration(ctx, configID) +} + +func (f *fakeSchedulerClient) GetSchedulerSession(ctx context.Context, sessionID string) (*domain.SchedulerSession, error) { + f.sessionID = sessionID + if f.getSchedulerSession == nil { + return nil, errors.New("unexpected GetSchedulerSession") + } + return f.getSchedulerSession(ctx, sessionID) +} + +func (f *fakeSchedulerClient) GetBooking(ctx context.Context, bookingID string) (*domain.Booking, error) { + f.bookingID = bookingID + if f.getBooking == nil { + return nil, errors.New("unexpected GetBooking") + } + return f.getBooking(ctx, bookingID) +} + +func (f *fakeSchedulerClient) ListGroupEvents(ctx context.Context, grantID, configID, calendarID string, startTime, endTime int64) ([]domain.GroupEvent, error) { + f.grantID = grantID + f.configID = configID + f.calendarID = calendarID + f.startTime = startTime + f.endTime = endTime + if f.listGroupEvents == nil { + return nil, errors.New("unexpected ListGroupEvents") + } + return f.listGroupEvents(ctx, grantID, configID, calendarID, startTime, endTime) +} + +func TestRegisterSchedulerHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeSchedulerClient + assert func(*testing.T, *fakeSchedulerClient, rpcTestResponse) + }{ + { + name: "scheduler.config.list returns configurations", + method: "scheduler.config.list", + params: `{}`, + client: &fakeSchedulerClient{ + listSchedulerConfigurations: func(ctx context.Context) ([]domain.SchedulerConfiguration, error) { + return []domain.SchedulerConfiguration{{ID: "config-1", Name: "Intro"}}, nil + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + + var result schedulerConfigListResult + unmarshalResult(t, resp, &result) + if len(result.Configurations) != 1 || result.Configurations[0].ID != "config-1" { + t.Fatalf("configurations = %+v, want config-1", result.Configurations) + } + }, + }, + { + name: "scheduler.config.get requires config_id", + method: "scheduler.config.get", + params: `{}`, + client: &fakeSchedulerClient{}, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.configID != "" { + t.Fatalf("GetSchedulerConfiguration called with config %q, want no call", client.configID) + } + }, + }, + { + name: "scheduler.config.get client error maps to internal error", + method: "scheduler.config.get", + params: `{"config_id":"config-1"}`, + client: &fakeSchedulerClient{ + getSchedulerConfiguration: func(ctx context.Context, configID string) (*domain.SchedulerConfiguration, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.configID != "config-1" { + t.Fatalf("configID = %q, want config-1", client.configID) + } + }, + }, + { + name: "scheduler.session.get returns session", + method: "scheduler.session.get", + params: `{"session_id":"session-1"}`, + client: &fakeSchedulerClient{ + getSchedulerSession: func(ctx context.Context, sessionID string) (*domain.SchedulerSession, error) { + return &domain.SchedulerSession{SessionID: sessionID, ConfigurationID: "config-1"}, nil + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.sessionID != "session-1" { + t.Fatalf("sessionID = %q, want session-1", client.sessionID) + } + + var session domain.SchedulerSession + unmarshalResult(t, resp, &session) + if session.SessionID != "session-1" || session.ConfigurationID != "config-1" { + t.Fatalf("session = %+v, want session-1 config-1", session) + } + }, + }, + { + name: "scheduler.session.get requires session_id", + method: "scheduler.session.get", + params: `{}`, + client: &fakeSchedulerClient{}, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.sessionID != "" { + t.Fatalf("GetSchedulerSession called with session %q, want no call", client.sessionID) + } + }, + }, + { + name: "scheduler.booking.get returns booking", + method: "scheduler.booking.get", + params: `{"booking_id":"booking-1"}`, + client: &fakeSchedulerClient{ + getBooking: func(ctx context.Context, bookingID string) (*domain.Booking, error) { + return &domain.Booking{BookingID: bookingID, Title: "Intro", Status: "confirmed"}, nil + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.bookingID != "booking-1" { + t.Fatalf("bookingID = %q, want booking-1", client.bookingID) + } + + var booking domain.Booking + unmarshalResult(t, resp, &booking) + if booking.BookingID != "booking-1" || booking.Title != "Intro" { + t.Fatalf("booking = %+v, want booking-1 Intro", booking) + } + }, + }, + { + name: "scheduler.booking.get requires booking_id", + method: "scheduler.booking.get", + params: `{}`, + client: &fakeSchedulerClient{}, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.bookingID != "" { + t.Fatalf("GetBooking called with booking %q, want no call", client.bookingID) + } + }, + }, + { + name: "scheduler.groupEvent.list forwards grant and returns events", + method: "scheduler.groupEvent.list", + params: `{"grant_id":"grant-1","config_id":"config-1","calendar_id":"cal-1","start_time":1710000000,"end_time":1710003600}`, + client: &fakeSchedulerClient{ + listGroupEvents: func(ctx context.Context, grantID, configID, calendarID string, startTime, endTime int64) ([]domain.GroupEvent, error) { + return []domain.GroupEvent{{ID: "event-1", CalendarID: calendarID, Title: "Workshop"}}, nil + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireNoRPCError(t, resp) + if client.grantID != "grant-1" || client.configID != "config-1" || client.calendarID != "cal-1" { + t.Fatalf("args = %q %q %q, want grant-1 config-1 cal-1", client.grantID, client.configID, client.calendarID) + } + if client.startTime != 1710000000 || client.endTime != 1710003600 { + t.Fatalf("times = %d %d, want forwarded window", client.startTime, client.endTime) + } + + var result schedulerGroupEventResult + unmarshalResult(t, resp, &result) + if len(result.Events) != 1 || result.Events[0].ID != "event-1" { + t.Fatalf("events = %+v, want event-1", result.Events) + } + }, + }, + { + name: "scheduler.groupEvent.list requires calendar_id", + method: "scheduler.groupEvent.list", + params: `{"config_id":"config-1"}`, + defaultGrant: "default-grant", + client: &fakeSchedulerClient{}, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.grantID != "" { + t.Fatalf("ListGroupEvents called with grant %q, want no call", client.grantID) + } + }, + }, + { + name: "scheduler.groupEvent.list requires time window", + method: "scheduler.groupEvent.list", + params: `{"config_id":"config-1","calendar_id":"cal-1"}`, + defaultGrant: "default-grant", + client: &fakeSchedulerClient{}, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + if client.grantID != "" { + t.Fatalf("ListGroupEvents called with grant %q, want no call", client.grantID) + } + }, + }, + { + name: "scheduler.groupEvent.list client error maps to internal error", + method: "scheduler.groupEvent.list", + params: `{"grant_id":"grant-1","config_id":"config-1","calendar_id":"cal-1","start_time":1710000000,"end_time":1710003600}`, + client: &fakeSchedulerClient{ + listGroupEvents: func(ctx context.Context, grantID, configID, calendarID string, startTime, endTime int64) ([]domain.GroupEvent, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, client *fakeSchedulerClient, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InternalError) + if client.grantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", client.grantID) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterSchedulerHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchCalendarRequest(t, d, tt.method, tt.params) + tt.assert(t, tt.client, resp) + }) + } +} diff --git a/internal/adapters/rpcserver/handlers_templates.go b/internal/adapters/rpcserver/handlers_templates.go new file mode 100644 index 0000000..e61d6ed --- /dev/null +++ b/internal/adapters/rpcserver/handlers_templates.go @@ -0,0 +1,421 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type templateWorkflowListParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + domain.CursorListParams +} + +type templateListResult struct { + Templates []domain.RemoteTemplate `json:"templates"` + NextCursor string `json:"next_cursor,omitempty"` +} + +type templateGetParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + TemplateID string `json:"template_id"` +} + +type templateCreateParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + domain.CreateRemoteTemplateRequest +} + +type templateUpdateParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + TemplateID string `json:"template_id"` + domain.UpdateRemoteTemplateRequest +} + +type templateDeleteParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + TemplateID string `json:"template_id"` +} + +type templateRenderParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + TemplateID string `json:"template_id"` + domain.TemplateRenderRequest +} + +type templateRenderHTMLParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + domain.TemplateRenderHTMLRequest +} + +type workflowListResult struct { + Workflows []domain.RemoteWorkflow `json:"workflows"` + NextCursor string `json:"next_cursor,omitempty"` +} + +type workflowGetParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + WorkflowID string `json:"workflow_id"` +} + +type workflowCreateParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + domain.CreateRemoteWorkflowRequest +} + +type workflowUpdateParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + WorkflowID string `json:"workflow_id"` + domain.UpdateRemoteWorkflowRequest +} + +type workflowDeleteParams struct { + Scope string `json:"scope,omitempty"` + GrantID string `json:"grant_id,omitempty"` + WorkflowID string `json:"workflow_id"` +} + +func RegisterTemplateWorkflowHandlers(d *Dispatcher, client ports.TemplateWorkflowClient, defaultGrant string) { + d.Register("template.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateWorkflowListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + resp, err := client.ListRemoteTemplates(ctx, scope, grantID, &p.CursorListParams) + if err != nil { + return nil, fmt.Errorf("template.list: %w", err) + } + return templateListResult{Templates: resp.Data, NextCursor: resp.NextCursor}, nil + }) + + d.Register("template.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.TemplateID == "" { + return nil, NewRPCError(InvalidParams, "template_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + template, err := client.GetRemoteTemplate(ctx, scope, grantID, p.TemplateID) + if err != nil { + return nil, fmt.Errorf("template.get: %w", err) + } + return template, nil + }) + + d.Register("template.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + template, err := client.CreateRemoteTemplate(ctx, scope, grantID, &p.CreateRemoteTemplateRequest) + if err != nil { + return nil, fmt.Errorf("template.create: %w", err) + } + return template, nil + }) + + d.Register("template.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.TemplateID == "" { + return nil, NewRPCError(InvalidParams, "template_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + template, err := client.UpdateRemoteTemplate(ctx, scope, grantID, p.TemplateID, &p.UpdateRemoteTemplateRequest) + if err != nil { + return nil, fmt.Errorf("template.update: %w", err) + } + return template, nil + }) + + d.Register("template.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.TemplateID == "" { + return nil, NewRPCError(InvalidParams, "template_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + if err := client.DeleteRemoteTemplate(ctx, scope, grantID, p.TemplateID); err != nil { + return nil, fmt.Errorf("template.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) + + d.Register("template.render", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateRenderParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.TemplateID == "" { + return nil, NewRPCError(InvalidParams, "template_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + result, err := client.RenderRemoteTemplate(ctx, scope, grantID, p.TemplateID, &p.TemplateRenderRequest) + if err != nil { + return nil, fmt.Errorf("template.render: %w", err) + } + return result, nil + }) + + d.Register("template.renderHTML", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateRenderHTMLParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + result, err := client.RenderRemoteTemplateHTML(ctx, scope, grantID, &p.TemplateRenderHTMLRequest) + if err != nil { + return nil, fmt.Errorf("template.renderHTML: %w", err) + } + return result, nil + }) + + d.Register("workflow.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p templateWorkflowListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + resp, err := client.ListWorkflows(ctx, scope, grantID, &p.CursorListParams) + if err != nil { + return nil, fmt.Errorf("workflow.list: %w", err) + } + return workflowListResult{Workflows: resp.Data, NextCursor: resp.NextCursor}, nil + }) + + d.Register("workflow.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workflowGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkflowID == "" { + return nil, NewRPCError(InvalidParams, "workflow_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + workflow, err := client.GetWorkflow(ctx, scope, grantID, p.WorkflowID) + if err != nil { + return nil, fmt.Errorf("workflow.get: %w", err) + } + return workflow, nil + }) + + d.Register("workflow.create", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workflowCreateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + workflow, err := client.CreateWorkflow(ctx, scope, grantID, &p.CreateRemoteWorkflowRequest) + if err != nil { + return nil, fmt.Errorf("workflow.create: %w", err) + } + return workflow, nil + }) + + d.Register("workflow.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workflowUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkflowID == "" { + return nil, NewRPCError(InvalidParams, "workflow_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + workflow, err := client.UpdateWorkflow(ctx, scope, grantID, p.WorkflowID, &p.UpdateRemoteWorkflowRequest) + if err != nil { + return nil, fmt.Errorf("workflow.update: %w", err) + } + return workflow, nil + }) + + d.Register("workflow.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p workflowDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.WorkflowID == "" { + return nil, NewRPCError(InvalidParams, "workflow_id required", nil) + } + + scope, err := parseRPCRemoteScope(p.Scope) + if err != nil { + return nil, err + } + grantID := p.GrantID + if scope == domain.ScopeGrant { + grantID, err = resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + } + + if err := client.DeleteWorkflow(ctx, scope, grantID, p.WorkflowID); err != nil { + return nil, fmt.Errorf("workflow.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} + +func parseRPCRemoteScope(scope string) (domain.RemoteScope, error) { + if scope == "" { + return domain.ScopeApplication, nil + } + parsed, err := domain.ParseRemoteScope(scope) + if err != nil { + return "", NewRPCError(InvalidParams, "invalid scope", nil) + } + return parsed, nil +} diff --git a/internal/adapters/rpcserver/handlers_templates_test.go b/internal/adapters/rpcserver/handlers_templates_test.go new file mode 100644 index 0000000..b69303b --- /dev/null +++ b/internal/adapters/rpcserver/handlers_templates_test.go @@ -0,0 +1,458 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeTemplateWorkflowClient struct { + ports.TemplateWorkflowClient + + listRemoteTemplates func(context.Context, domain.RemoteScope, string, *domain.CursorListParams) (*domain.RemoteTemplateListResponse, error) + getRemoteTemplate func(context.Context, domain.RemoteScope, string, string) (*domain.RemoteTemplate, error) + createRemoteTemplate func(context.Context, domain.RemoteScope, string, *domain.CreateRemoteTemplateRequest) (*domain.RemoteTemplate, error) + updateRemoteTemplate func(context.Context, domain.RemoteScope, string, string, *domain.UpdateRemoteTemplateRequest) (*domain.RemoteTemplate, error) + deleteRemoteTemplate func(context.Context, domain.RemoteScope, string, string) error + renderRemoteTemplate func(context.Context, domain.RemoteScope, string, string, *domain.TemplateRenderRequest) (domain.TemplateRenderResult, error) + renderRemoteTemplateHTML func(context.Context, domain.RemoteScope, string, *domain.TemplateRenderHTMLRequest) (domain.TemplateRenderResult, error) + listWorkflows func(context.Context, domain.RemoteScope, string, *domain.CursorListParams) (*domain.RemoteWorkflowListResponse, error) + getWorkflow func(context.Context, domain.RemoteScope, string, string) (*domain.RemoteWorkflow, error) + createWorkflow func(context.Context, domain.RemoteScope, string, *domain.CreateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) + updateWorkflow func(context.Context, domain.RemoteScope, string, string, *domain.UpdateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) + deleteWorkflow func(context.Context, domain.RemoteScope, string, string) error +} + +func (f *fakeTemplateWorkflowClient) ListRemoteTemplates(ctx context.Context, scope domain.RemoteScope, grantID string, params *domain.CursorListParams) (*domain.RemoteTemplateListResponse, error) { + if f.listRemoteTemplates == nil { + return nil, errors.New("unexpected ListRemoteTemplates") + } + return f.listRemoteTemplates(ctx, scope, grantID, params) +} + +func (f *fakeTemplateWorkflowClient) GetRemoteTemplate(ctx context.Context, scope domain.RemoteScope, grantID, templateID string) (*domain.RemoteTemplate, error) { + if f.getRemoteTemplate == nil { + return nil, errors.New("unexpected GetRemoteTemplate") + } + return f.getRemoteTemplate(ctx, scope, grantID, templateID) +} + +func (f *fakeTemplateWorkflowClient) CreateRemoteTemplate(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.CreateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + if f.createRemoteTemplate == nil { + return nil, errors.New("unexpected CreateRemoteTemplate") + } + return f.createRemoteTemplate(ctx, scope, grantID, req) +} + +func (f *fakeTemplateWorkflowClient) UpdateRemoteTemplate(ctx context.Context, scope domain.RemoteScope, grantID, templateID string, req *domain.UpdateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + if f.updateRemoteTemplate == nil { + return nil, errors.New("unexpected UpdateRemoteTemplate") + } + return f.updateRemoteTemplate(ctx, scope, grantID, templateID, req) +} + +func (f *fakeTemplateWorkflowClient) DeleteRemoteTemplate(ctx context.Context, scope domain.RemoteScope, grantID, templateID string) error { + if f.deleteRemoteTemplate == nil { + return errors.New("unexpected DeleteRemoteTemplate") + } + return f.deleteRemoteTemplate(ctx, scope, grantID, templateID) +} + +func (f *fakeTemplateWorkflowClient) RenderRemoteTemplate(ctx context.Context, scope domain.RemoteScope, grantID, templateID string, req *domain.TemplateRenderRequest) (domain.TemplateRenderResult, error) { + if f.renderRemoteTemplate == nil { + return nil, errors.New("unexpected RenderRemoteTemplate") + } + return f.renderRemoteTemplate(ctx, scope, grantID, templateID, req) +} + +func (f *fakeTemplateWorkflowClient) RenderRemoteTemplateHTML(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.TemplateRenderHTMLRequest) (domain.TemplateRenderResult, error) { + if f.renderRemoteTemplateHTML == nil { + return nil, errors.New("unexpected RenderRemoteTemplateHTML") + } + return f.renderRemoteTemplateHTML(ctx, scope, grantID, req) +} + +func (f *fakeTemplateWorkflowClient) ListWorkflows(ctx context.Context, scope domain.RemoteScope, grantID string, params *domain.CursorListParams) (*domain.RemoteWorkflowListResponse, error) { + if f.listWorkflows == nil { + return nil, errors.New("unexpected ListWorkflows") + } + return f.listWorkflows(ctx, scope, grantID, params) +} + +func (f *fakeTemplateWorkflowClient) GetWorkflow(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string) (*domain.RemoteWorkflow, error) { + if f.getWorkflow == nil { + return nil, errors.New("unexpected GetWorkflow") + } + return f.getWorkflow(ctx, scope, grantID, workflowID) +} + +func (f *fakeTemplateWorkflowClient) CreateWorkflow(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.CreateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + if f.createWorkflow == nil { + return nil, errors.New("unexpected CreateWorkflow") + } + return f.createWorkflow(ctx, scope, grantID, req) +} + +func (f *fakeTemplateWorkflowClient) UpdateWorkflow(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string, req *domain.UpdateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + if f.updateWorkflow == nil { + return nil, errors.New("unexpected UpdateWorkflow") + } + return f.updateWorkflow(ctx, scope, grantID, workflowID, req) +} + +func (f *fakeTemplateWorkflowClient) DeleteWorkflow(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string) error { + if f.deleteWorkflow == nil { + return errors.New("unexpected DeleteWorkflow") + } + return f.deleteWorkflow(ctx, scope, grantID, workflowID) +} + +func TestRegisterTemplateWorkflowHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + enabled := true + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeTemplateWorkflowClient + assert func(*testing.T, rpcTestResponse) + }{ + { + name: "template.list app scope succeeds without default grant", + method: "template.list", + params: `{"scope":"app","limit":2,"page_token":"cursor-1"}`, + client: &fakeTemplateWorkflowClient{listRemoteTemplates: func(ctx context.Context, scope domain.RemoteScope, grantID string, params *domain.CursorListParams) (*domain.RemoteTemplateListResponse, error) { + if scope != domain.ScopeApplication || grantID != "" || params.Limit != 2 || params.PageToken != "cursor-1" { + t.Fatalf("list args = %q %q %+v, want app empty grant cursor params", scope, grantID, params) + } + return &domain.RemoteTemplateListResponse{Data: []domain.RemoteTemplate{{ID: "tpl-1"}}, NextCursor: "cursor-2"}, nil + }}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result templateListResult + unmarshalResult(t, resp, &result) + if len(result.Templates) != 1 || result.Templates[0].ID != "tpl-1" || result.NextCursor != "cursor-2" { + t.Fatalf("result = %+v, want tpl-1 cursor-2", result) + } + }, + }, + { + name: "template.list grant scope requires grant_id", + method: "template.list", + params: `{"scope":"grant"}`, + client: &fakeTemplateWorkflowClient{}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "template.get returns template", + method: "template.get", + params: `{"scope":"grant","grant_id":"grant-1","template_id":"tpl-1"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{getRemoteTemplate: func(ctx context.Context, scope domain.RemoteScope, grantID, templateID string) (*domain.RemoteTemplate, error) { + if scope != domain.ScopeGrant || grantID != "grant-1" || templateID != "tpl-1" { + t.Fatalf("get args = %q %q %q, want grant grant-1 tpl-1", scope, grantID, templateID) + } + return &domain.RemoteTemplate{ID: templateID, Name: "Welcome"}, nil + }}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var template domain.RemoteTemplate + unmarshalResult(t, resp, &template) + if template.ID != "tpl-1" || template.Name != "Welcome" { + t.Fatalf("template = %+v, want tpl-1 Welcome", template) + } + }, + }, + { + name: "template.create forwards request", + method: "template.create", + params: `{"grant_id":"grant-1","name":"Welcome","engine":"handlebars","subject":"Hi","body":"Hello {{name}}"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{createRemoteTemplate: func(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.CreateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + if scope != domain.ScopeApplication || grantID != "grant-1" || req.Name != "Welcome" || req.Body != "Hello {{name}}" { + t.Fatalf("create args = %q %q %+v, want forwarded template request", scope, grantID, req) + } + return &domain.RemoteTemplate{ID: "tpl-1", Name: req.Name}, nil + }}, + assert: requireNoRPCError, + }, + { + name: "template.update forwards request", + method: "template.update", + params: `{"template_id":"tpl-1","name":"Updated"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{updateRemoteTemplate: func(ctx context.Context, scope domain.RemoteScope, grantID, templateID string, req *domain.UpdateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + if scope != domain.ScopeApplication || grantID != "" || templateID != "tpl-1" || req.Name == nil || *req.Name != "Updated" { + t.Fatalf("update args = %q %q %q %+v, want forwarded update", scope, grantID, templateID, req) + } + return &domain.RemoteTemplate{ID: templateID, Name: *req.Name}, nil + }}, + assert: requireNoRPCError, + }, + { + name: "template.delete returns deleted", + method: "template.delete", + params: `{"template_id":"tpl-1"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{deleteRemoteTemplate: func(ctx context.Context, scope domain.RemoteScope, grantID, templateID string) error { + if scope != domain.ScopeApplication || grantID != "" || templateID != "tpl-1" { + t.Fatalf("delete args = %q %q %q, want app empty grant tpl-1", scope, grantID, templateID) + } + return nil + }}, + assert: assertDeleted, + }, + { + name: "template.render returns result", + method: "template.render", + params: `{"template_id":"tpl-1","variables":{"name":"Ada"},"strict":true}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{renderRemoteTemplate: func(ctx context.Context, scope domain.RemoteScope, grantID, templateID string, req *domain.TemplateRenderRequest) (domain.TemplateRenderResult, error) { + if scope != domain.ScopeApplication || grantID != "" || templateID != "tpl-1" || req.Strict == nil || !*req.Strict || req.Variables["name"] != "Ada" { + t.Fatalf("render args = %q %q %q %+v, want forwarded render request", scope, grantID, templateID, req) + } + return domain.TemplateRenderResult{"body": "Hello Ada"}, nil + }}, + assert: assertRenderBody("Hello Ada"), + }, + { + name: "template.renderHTML returns result", + method: "template.renderHTML", + params: `{"body":"Hello {{name}}","engine":"handlebars","variables":{"name":"Ada"}}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{renderRemoteTemplateHTML: func(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.TemplateRenderHTMLRequest) (domain.TemplateRenderResult, error) { + if scope != domain.ScopeApplication || grantID != "" || req.Body != "Hello {{name}}" || req.Engine != "handlebars" || req.Variables["name"] != "Ada" { + t.Fatalf("renderHTML args = %q %q %+v, want forwarded render HTML request", scope, grantID, req) + } + return domain.TemplateRenderResult{"body": "Hello Ada"}, nil + }}, + assert: assertRenderBody("Hello Ada"), + }, + { + name: "workflow.list returns workflows", + method: "workflow.list", + params: `{"limit":2}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{listWorkflows: func(ctx context.Context, scope domain.RemoteScope, grantID string, params *domain.CursorListParams) (*domain.RemoteWorkflowListResponse, error) { + if scope != domain.ScopeApplication || grantID != "" || params.Limit != 2 { + t.Fatalf("workflow list args = %q %q %+v, want app empty grant limit 2", scope, grantID, params) + } + return &domain.RemoteWorkflowListResponse{Data: []domain.RemoteWorkflow{{ID: "wf-1"}}, NextCursor: "cursor-2"}, nil + }}, + assert: func(t *testing.T, resp rpcTestResponse) { + requireNoRPCError(t, resp) + var result workflowListResult + unmarshalResult(t, resp, &result) + if len(result.Workflows) != 1 || result.Workflows[0].ID != "wf-1" || result.NextCursor != "cursor-2" { + t.Fatalf("result = %+v, want wf-1 cursor-2", result) + } + }, + }, + { + name: "workflow.get returns workflow", + method: "workflow.get", + params: `{"scope":"grant","grant_id":"grant-1","workflow_id":"wf-1"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{getWorkflow: func(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string) (*domain.RemoteWorkflow, error) { + if scope != domain.ScopeGrant || grantID != "grant-1" || workflowID != "wf-1" { + t.Fatalf("workflow get args = %q %q %q, want grant grant-1 wf-1", scope, grantID, workflowID) + } + return &domain.RemoteWorkflow{ID: workflowID, Name: "Reminder"}, nil + }}, + assert: requireNoRPCError, + }, + { + name: "workflow.create forwards request", + method: "workflow.create", + params: `{"name":"Reminder","template_id":"tpl-1","trigger_event":"booking.created","is_enabled":true}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{createWorkflow: func(ctx context.Context, scope domain.RemoteScope, grantID string, req *domain.CreateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + if scope != domain.ScopeApplication || grantID != "" || req.Name != "Reminder" || req.TemplateID != "tpl-1" || req.IsEnabled == nil || !*req.IsEnabled { + t.Fatalf("workflow create args = %q %q %+v, want forwarded request", scope, grantID, req) + } + return &domain.RemoteWorkflow{ID: "wf-1", Name: req.Name}, nil + }}, + assert: requireNoRPCError, + }, + { + name: "workflow.update forwards request", + method: "workflow.update", + params: `{"workflow_id":"wf-1","is_enabled":true}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{updateWorkflow: func(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string, req *domain.UpdateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + if scope != domain.ScopeApplication || grantID != "" || workflowID != "wf-1" || req.IsEnabled == nil || *req.IsEnabled != enabled { + t.Fatalf("workflow update args = %q %q %q %+v, want forwarded update", scope, grantID, workflowID, req) + } + return &domain.RemoteWorkflow{ID: workflowID, IsEnabled: *req.IsEnabled}, nil + }}, + assert: requireNoRPCError, + }, + { + name: "workflow.delete returns deleted", + method: "workflow.delete", + params: `{"workflow_id":"wf-1"}`, + defaultGrant: "default-grant", + client: &fakeTemplateWorkflowClient{deleteWorkflow: func(ctx context.Context, scope domain.RemoteScope, grantID, workflowID string) error { + if scope != domain.ScopeApplication || grantID != "" || workflowID != "wf-1" { + t.Fatalf("workflow delete args = %q %q %q, want app empty grant wf-1", scope, grantID, workflowID) + } + return nil + }}, + assert: assertDeleted, + }, + } + + for _, spec := range templateWorkflowErrorSpecs(clientErr) { + tests = append(tests, spec...) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterTemplateWorkflowHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchTemplateWorkflowRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func templateWorkflowErrorSpecs(clientErr error) [][]struct { + name string + method string + params string + defaultGrant string + client *fakeTemplateWorkflowClient + assert func(*testing.T, rpcTestResponse) +} { + invalidParams := func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InvalidParams) } + internalError := func(t *testing.T, resp rpcTestResponse) { requireRPCErrorCode(t, resp, InternalError) } + return [][]struct { + name string + method string + params string + defaultGrant string + client *fakeTemplateWorkflowClient + assert func(*testing.T, rpcTestResponse) + }{ + errorSpec("template.list", `{"limit":1}`, `{"scope":"grant"}`, &fakeTemplateWorkflowClient{listRemoteTemplates: func(context.Context, domain.RemoteScope, string, *domain.CursorListParams) (*domain.RemoteTemplateListResponse, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("template.get", `{"template_id":"tpl-1"}`, `{}`, &fakeTemplateWorkflowClient{getRemoteTemplate: func(context.Context, domain.RemoteScope, string, string) (*domain.RemoteTemplate, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("template.create", `{"name":"Welcome"}`, `{"scope":"grant"}`, &fakeTemplateWorkflowClient{createRemoteTemplate: func(context.Context, domain.RemoteScope, string, *domain.CreateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("template.update", `{"template_id":"tpl-1","name":"Updated"}`, `{}`, &fakeTemplateWorkflowClient{updateRemoteTemplate: func(context.Context, domain.RemoteScope, string, string, *domain.UpdateRemoteTemplateRequest) (*domain.RemoteTemplate, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("template.delete", `{"template_id":"tpl-1"}`, `{}`, &fakeTemplateWorkflowClient{deleteRemoteTemplate: func(context.Context, domain.RemoteScope, string, string) error { + return clientErr + }}, invalidParams, internalError), + errorSpec("template.render", `{"template_id":"tpl-1"}`, `{}`, &fakeTemplateWorkflowClient{renderRemoteTemplate: func(context.Context, domain.RemoteScope, string, string, *domain.TemplateRenderRequest) (domain.TemplateRenderResult, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("template.renderHTML", `{"body":"Hello","engine":"handlebars"}`, `{"scope":"grant"}`, &fakeTemplateWorkflowClient{renderRemoteTemplateHTML: func(context.Context, domain.RemoteScope, string, *domain.TemplateRenderHTMLRequest) (domain.TemplateRenderResult, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("workflow.list", `{"limit":1}`, `{"scope":"grant"}`, &fakeTemplateWorkflowClient{listWorkflows: func(context.Context, domain.RemoteScope, string, *domain.CursorListParams) (*domain.RemoteWorkflowListResponse, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("workflow.get", `{"workflow_id":"wf-1"}`, `{}`, &fakeTemplateWorkflowClient{getWorkflow: func(context.Context, domain.RemoteScope, string, string) (*domain.RemoteWorkflow, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("workflow.create", `{"name":"Reminder"}`, `{"scope":"grant"}`, &fakeTemplateWorkflowClient{createWorkflow: func(context.Context, domain.RemoteScope, string, *domain.CreateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("workflow.update", `{"workflow_id":"wf-1","name":"Updated"}`, `{}`, &fakeTemplateWorkflowClient{updateWorkflow: func(context.Context, domain.RemoteScope, string, string, *domain.UpdateRemoteWorkflowRequest) (*domain.RemoteWorkflow, error) { + return nil, clientErr + }}, invalidParams, internalError), + errorSpec("workflow.delete", `{"workflow_id":"wf-1"}`, `{}`, &fakeTemplateWorkflowClient{deleteWorkflow: func(context.Context, domain.RemoteScope, string, string) error { + return clientErr + }}, invalidParams, internalError), + } +} + +func errorSpec( + method string, + clientErrorParams string, + missingParams string, + client *fakeTemplateWorkflowClient, + invalidParams func(*testing.T, rpcTestResponse), + internalError func(*testing.T, rpcTestResponse), +) []struct { + name string + method string + params string + defaultGrant string + client *fakeTemplateWorkflowClient + assert func(*testing.T, rpcTestResponse) +} { + badScopeParams := strings.Replace(clientErrorParams, "{", `{"scope":"bad",`, 1) + + return []struct { + name string + method string + params string + defaultGrant string + client *fakeTemplateWorkflowClient + assert func(*testing.T, rpcTestResponse) + }{ + {name: method + " missing required param returns invalid params", method: method, params: missingParams, client: &fakeTemplateWorkflowClient{}, assert: invalidParams}, + {name: method + " bad scope returns invalid params", method: method, params: badScopeParams, defaultGrant: "default-grant", client: &fakeTemplateWorkflowClient{}, assert: invalidParams}, + {name: method + " client error maps to internal error", method: method, params: clientErrorParams, defaultGrant: "default-grant", client: client, assert: internalError}, + } +} + +func assertDeleted(t *testing.T, resp rpcTestResponse) { + t.Helper() + + requireNoRPCError(t, resp) + var result deletedResult + unmarshalResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } +} + +func assertRenderBody(want string) func(*testing.T, rpcTestResponse) { + return func(t *testing.T, resp rpcTestResponse) { + t.Helper() + + requireNoRPCError(t, resp) + var result domain.TemplateRenderResult + unmarshalResult(t, resp, &result) + if result["body"] != want { + t.Fatalf("body = %v, want %q", result["body"], want) + } + } +} + +func dispatchTemplateWorkflowRequest(t *testing.T, d *Dispatcher, method, params string) rpcTestResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp rpcTestResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} diff --git a/internal/adapters/rpcserver/handlers_thread.go b/internal/adapters/rpcserver/handlers_thread.go new file mode 100644 index 0000000..cab6745 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_thread.go @@ -0,0 +1,80 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type threadListParams struct { + GrantID string `json:"grant_id,omitempty"` + Limit int `json:"limit,omitempty"` + PageToken string `json:"page_token,omitempty"` + LatestMessageAfter int64 `json:"latest_message_after,omitempty"` + Unread *bool `json:"unread,omitempty"` +} + +type threadListResult struct { + Threads []domain.Thread `json:"threads"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` +} + +type threadGetParams struct { + GrantID string `json:"grant_id,omitempty"` + ThreadID string `json:"thread_id"` +} + +func RegisterThreadHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("thread.list", func(ctx context.Context, params json.RawMessage) (any, error) { + var p threadListParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + resp, err := client.GetThreadsWithCursor(ctx, grantID, &domain.ThreadQueryParams{ + Limit: p.Limit, + PageToken: p.PageToken, + Unread: p.Unread, + LatestMsgAfter: p.LatestMessageAfter, + }) + if err != nil { + return nil, fmt.Errorf("thread.list: %w", err) + } + + return threadListResult{ + Threads: resp.Data, + NextCursor: resp.Pagination.NextCursor, + HasMore: resp.Pagination.HasMore, + }, nil + }) + + d.Register("thread.get", func(ctx context.Context, params json.RawMessage) (any, error) { + var p threadGetParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ThreadID == "" { + return nil, NewRPCError(InvalidParams, "thread_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + thread, err := client.GetThread(ctx, grantID, p.ThreadID) + if err != nil { + return nil, fmt.Errorf("thread.get: %w", err) + } + return thread, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_thread_test.go b/internal/adapters/rpcserver/handlers_thread_test.go new file mode 100644 index 0000000..5bc417f --- /dev/null +++ b/internal/adapters/rpcserver/handlers_thread_test.go @@ -0,0 +1,295 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeThreadClient struct { + ports.MessageClient + + getThreads func(context.Context, string, *domain.ThreadQueryParams) ([]domain.Thread, error) + getThreadsWithCursor func(context.Context, string, *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) + getThread func(context.Context, string, string) (*domain.Thread, error) +} + +func (f *fakeThreadClient) GetThreads(ctx context.Context, grantID string, params *domain.ThreadQueryParams) ([]domain.Thread, error) { + if f.getThreads == nil { + return nil, errors.New("unexpected GetThreads") + } + return f.getThreads(ctx, grantID, params) +} + +func (f *fakeThreadClient) GetThreadsWithCursor(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if f.getThreadsWithCursor == nil { + return nil, errors.New("unexpected GetThreadsWithCursor") + } + return f.getThreadsWithCursor(ctx, grantID, params) +} + +func (f *fakeThreadClient) GetThread(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { + if f.getThread == nil { + return nil, errors.New("unexpected GetThread") + } + return f.getThread(ctx, grantID, threadID) +} + +type threadRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +func TestRegisterThreadHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeThreadClient + assert func(*testing.T, threadRPCResponse) + }{ + { + name: "thread.list returns threads and next cursor", + method: "thread.list", + params: `{"limit":2}`, + defaultGrant: "default-grant", + client: &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "default-grant") + } + return &domain.ThreadListResponse{ + Data: []domain.Thread{ + {ID: "thread-1", Subject: "Hello"}, + {ID: "thread-2", Subject: "World"}, + }, + Pagination: domain.Pagination{NextCursor: "cursor-2", HasMore: true}, + }, nil + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireNoThreadRPCError(t, resp) + + var result struct { + Threads []domain.Thread `json:"threads"` + NextCursor string `json:"next_cursor"` + HasMore bool `json:"has_more"` + } + unmarshalThreadResult(t, resp, &result) + if len(result.Threads) != 2 || result.Threads[0].ID != "thread-1" || result.Threads[1].ID != "thread-2" { + t.Fatalf("threads = %#v, want thread-1 and thread-2", result.Threads) + } + if result.NextCursor != "cursor-2" { + t.Fatalf("next_cursor = %q, want cursor-2", result.NextCursor) + } + if !result.HasMore { + t.Fatal("has_more = false, want true") + } + }, + }, + { + name: "thread.list forwards query params and request grant", + method: "thread.list", + params: `{"grant_id":"request-grant","limit":25,"page_token":"cursor-1","latest_message_after":1710000000,"unread":false}`, + defaultGrant: "default-grant", + client: &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if grantID != "request-grant" { + t.Fatalf("grantID = %q, want %q", grantID, "request-grant") + } + if params.Limit != 25 { + t.Fatalf("Limit = %d, want 25", params.Limit) + } + if params.PageToken != "cursor-1" { + t.Fatalf("PageToken = %q, want %q", params.PageToken, "cursor-1") + } + if params.LatestMsgAfter != 1710000000 { + t.Fatalf("LatestMsgAfter = %d, want 1710000000", params.LatestMsgAfter) + } + if params.Unread == nil || *params.Unread { + t.Fatalf("Unread = %#v, want pointer to false", params.Unread) + } + return &domain.ThreadListResponse{}, nil + }, + }, + assert: requireNoThreadRPCError, + }, + { + name: "thread.list leaves unread nil when omitted", + method: "thread.list", + params: `{"grant_id":"grant-1"}`, + client: &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if params.Unread != nil { + t.Fatalf("Unread = %#v, want nil", params.Unread) + } + return &domain.ThreadListResponse{}, nil + }, + }, + assert: requireNoThreadRPCError, + }, + { + name: "thread.list missing grant returns invalid params", + method: "thread.list", + params: `{}`, + client: &fakeThreadClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "thread.list malformed params returns invalid params", + method: "thread.list", + params: `{"limit":"bad"}`, + client: &fakeThreadClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "thread.get with thread_id returns the thread", + method: "thread.get", + params: `{"grant_id":"grant-1","thread_id":"thread-1"}`, + client: &fakeThreadClient{ + getThread: func(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want %q", grantID, "grant-1") + } + if threadID != "thread-1" { + t.Fatalf("threadID = %q, want %q", threadID, "thread-1") + } + return &domain.Thread{ID: "thread-1", Subject: "Hello"}, nil + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireNoThreadRPCError(t, resp) + + var thread domain.Thread + unmarshalThreadResult(t, resp, &thread) + if thread.ID != "thread-1" || thread.Subject != "Hello" { + t.Fatalf("thread = %#v, want thread-1 Hello", thread) + } + }, + }, + { + name: "thread.get missing thread_id returns invalid params", + method: "thread.get", + params: `{"grant_id":"grant-1"}`, + defaultGrant: "grant-1", + client: &fakeThreadClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "client error maps to internal error", + method: "thread.get", + params: `{"thread_id":"thread-1"}`, + defaultGrant: "default-grant", + client: &fakeThreadClient{ + getThread: func(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterThreadHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchThreadRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func TestRegisterThreadHandlers_WrapsClientErrors(t *testing.T) { + clientErr := errors.New("client unavailable") + d := NewDispatcher() + RegisterThreadHandlers(d, &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + return nil, clientErr + }, + getThread: func(ctx context.Context, grantID, threadID string) (*domain.Thread, error) { + return nil, clientErr + }, + }, "grant-1") + + tests := []struct { + name string + method string + params string + }{ + {name: "list", method: "thread.list", params: `{}`}, + {name: "get", method: "thread.get", params: `{"thread_id":"thread-1"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := d.handlers[tt.method](context.Background(), json.RawMessage(tt.params)) + if !errors.Is(err, clientErr) { + t.Fatalf("handler error = %v, want wrapped %v", err, clientErr) + } + }) + } +} + +func dispatchThreadRequest(t *testing.T, d *Dispatcher, method, params string) threadRPCResponse { + t.Helper() + + raw := []byte(`{"jsonrpc":"2.0","id":1,"method":"` + method + `","params":` + params + `}`) + got := d.Dispatch(context.Background(), raw) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp threadRPCResponse + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Fatalf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + return resp +} + +func requireNoThreadRPCError(t *testing.T, resp threadRPCResponse) { + t.Helper() + + if resp.Error != nil { + t.Fatalf("Error = %#v, want nil", resp.Error) + } +} + +func requireThreadRPCErrorCode(t *testing.T, resp threadRPCResponse, want int) { + t.Helper() + + if resp.Error == nil { + t.Fatal("Error = nil, want RPC error") + } + if resp.Error.Code != want { + t.Fatalf("Error.Code = %d, want %d", resp.Error.Code, want) + } +} + +func unmarshalThreadResult(t *testing.T, resp threadRPCResponse, dest any) { + t.Helper() + + if err := json.Unmarshal(resp.Result, dest); err != nil { + t.Fatalf("unmarshal result: %v", err) + } +} diff --git a/internal/adapters/rpcserver/handlers_thread_write.go b/internal/adapters/rpcserver/handlers_thread_write.go new file mode 100644 index 0000000..875f907 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_thread_write.go @@ -0,0 +1,64 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type threadUpdateParams struct { + GrantID string `json:"grant_id,omitempty"` + ThreadID string `json:"thread_id"` + domain.UpdateMessageRequest +} + +type threadDeleteParams struct { + GrantID string `json:"grant_id,omitempty"` + ThreadID string `json:"thread_id"` +} + +func RegisterThreadWriteHandlers(d *Dispatcher, client ports.MessageClient, defaultGrant string) { + d.Register("thread.update", func(ctx context.Context, params json.RawMessage) (any, error) { + var p threadUpdateParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ThreadID == "" { + return nil, NewRPCError(InvalidParams, "thread_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + thread, err := client.UpdateThread(ctx, grantID, p.ThreadID, &p.UpdateMessageRequest) + if err != nil { + return nil, fmt.Errorf("thread.update: %w", err) + } + return thread, nil + }) + + d.Register("thread.delete", func(ctx context.Context, params json.RawMessage) (any, error) { + var p threadDeleteParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + if p.ThreadID == "" { + return nil, NewRPCError(InvalidParams, "thread_id required", nil) + } + + grantID, err := resolveGrant(p.GrantID, defaultGrant) + if err != nil { + return nil, err + } + + if err := client.DeleteThread(ctx, grantID, p.ThreadID); err != nil { + return nil, fmt.Errorf("thread.delete: %w", err) + } + return deletedResult{Deleted: true}, nil + }) +} diff --git a/internal/adapters/rpcserver/handlers_thread_write_test.go b/internal/adapters/rpcserver/handlers_thread_write_test.go new file mode 100644 index 0000000..28d2b71 --- /dev/null +++ b/internal/adapters/rpcserver/handlers_thread_write_test.go @@ -0,0 +1,205 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakeThreadWriteClient struct { + ports.MessageClient + + updateThread func(context.Context, string, string, *domain.UpdateMessageRequest) (*domain.Thread, error) + deleteThread func(context.Context, string, string) error +} + +func (f *fakeThreadWriteClient) UpdateThread(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) { + if f.updateThread == nil { + return nil, errors.New("unexpected UpdateThread") + } + return f.updateThread(ctx, grantID, threadID, req) +} + +func (f *fakeThreadWriteClient) DeleteThread(ctx context.Context, grantID, threadID string) error { + if f.deleteThread == nil { + return errors.New("unexpected DeleteThread") + } + return f.deleteThread(ctx, grantID, threadID) +} + +func TestRegisterThreadWriteHandlers(t *testing.T) { + clientErr := errors.New("client unavailable") + + tests := []struct { + name string + method string + params string + defaultGrant string + client *fakeThreadWriteClient + assert func(*testing.T, threadRPCResponse) + }{ + { + name: "thread.update forwards update request", + method: "thread.update", + params: `{"grant_id":"grant-1","thread_id":"thread-1","unread":false,"starred":true,"folders":["inbox","important"]}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{ + updateThread: func(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) { + if grantID != "grant-1" { + t.Fatalf("grantID = %q, want grant-1", grantID) + } + if threadID != "thread-1" { + t.Fatalf("threadID = %q, want thread-1", threadID) + } + if req.Unread == nil || *req.Unread { + t.Fatalf("Unread = %#v, want pointer to false", req.Unread) + } + if req.Starred == nil || !*req.Starred { + t.Fatalf("Starred = %#v, want pointer to true", req.Starred) + } + if len(req.Folders) != 2 || req.Folders[0] != "inbox" || req.Folders[1] != "important" { + t.Fatalf("Folders = %#v, want inbox and important", req.Folders) + } + return &domain.Thread{ID: "thread-1", Unread: false, Starred: true}, nil + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireNoThreadRPCError(t, resp) + + var thread domain.Thread + unmarshalThreadResult(t, resp, &thread) + if thread.ID != "thread-1" || thread.Unread || !thread.Starred { + t.Fatalf("thread = %#v, want updated thread", thread) + } + }, + }, + { + name: "thread.delete deletes thread", + method: "thread.delete", + params: `{"thread_id":"thread-1"}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{ + deleteThread: func(ctx context.Context, grantID, threadID string) error { + if grantID != "default-grant" { + t.Fatalf("grantID = %q, want default-grant", grantID) + } + if threadID != "thread-1" { + t.Fatalf("threadID = %q, want thread-1", threadID) + } + return nil + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireNoThreadRPCError(t, resp) + + var result deletedResult + unmarshalThreadResult(t, resp, &result) + if !result.Deleted { + t.Fatal("deleted = false, want true") + } + }, + }, + { + name: "thread.update missing grant returns invalid params", + method: "thread.update", + params: `{"thread_id":"thread-1"}`, + client: &fakeThreadWriteClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "thread.update missing thread_id returns invalid params", + method: "thread.update", + params: `{"unread":true}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "thread.delete missing thread_id returns invalid params", + method: "thread.delete", + params: `{}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{}, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InvalidParams) + }, + }, + { + name: "thread.update client error maps to internal error", + method: "thread.update", + params: `{"thread_id":"thread-1","unread":true}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{ + updateThread: func(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) { + return nil, clientErr + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InternalError) + }, + }, + { + name: "thread.delete client error maps to internal error", + method: "thread.delete", + params: `{"thread_id":"thread-1"}`, + defaultGrant: "default-grant", + client: &fakeThreadWriteClient{ + deleteThread: func(ctx context.Context, grantID, threadID string) error { + return clientErr + }, + }, + assert: func(t *testing.T, resp threadRPCResponse) { + requireThreadRPCErrorCode(t, resp, InternalError) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + RegisterThreadWriteHandlers(d, tt.client, tt.defaultGrant) + + resp := dispatchThreadRequest(t, d, tt.method, tt.params) + tt.assert(t, resp) + }) + } +} + +func TestRegisterThreadWriteHandlers_WrapsClientErrors(t *testing.T) { + clientErr := errors.New("client unavailable") + d := NewDispatcher() + RegisterThreadWriteHandlers(d, &fakeThreadWriteClient{ + updateThread: func(ctx context.Context, grantID, threadID string, req *domain.UpdateMessageRequest) (*domain.Thread, error) { + return nil, clientErr + }, + deleteThread: func(ctx context.Context, grantID, threadID string) error { + return clientErr + }, + }, "grant-1") + + tests := []struct { + name string + method string + params string + }{ + {name: "update", method: "thread.update", params: `{"thread_id":"thread-1"}`}, + {name: "delete", method: "thread.delete", params: `{"thread_id":"thread-1"}`}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := d.handlers[tt.method](context.Background(), json.RawMessage(tt.params)) + if !errors.Is(err, clientErr) { + t.Fatalf("handler error = %v, want wrapped %v", err, clientErr) + } + }) + } +} diff --git a/internal/adapters/rpcserver/incremental.go b/internal/adapters/rpcserver/incremental.go new file mode 100644 index 0000000..59ed5dc --- /dev/null +++ b/internal/adapters/rpcserver/incremental.go @@ -0,0 +1,166 @@ +package rpcserver + +import ( + "cmp" + "context" + "encoding/json" + "slices" + "sync" + "time" +) + +type IntervalController struct { + mu sync.Mutex + fast time.Duration + idle time.Duration + focused bool +} + +func NewIntervalController(fast, idle time.Duration) *IntervalController { + return &IntervalController{fast: fast, idle: idle} +} + +func (c *IntervalController) SetFocused(focused bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.focused = focused +} + +func (c *IntervalController) Current() time.Duration { + c.mu.Lock() + defer c.mu.Unlock() + if c.focused { + return c.fast + } + return c.idle +} + +type incrementalState struct { + cursor int64 + boundaryIDs map[string]struct{} +} + +func runTicker(ctx context.Context, interval time.Duration, onError func(error), pollOnce func(context.Context) error) error { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if err := pollOnce(ctx); err != nil && onError != nil { + onError(err) + } + } + } +} + +func RunAdaptive(ctx context.Context, ctrl *IntervalController, onError func(error), pollOnce func(context.Context) error) error { + for { + timer := time.NewTimer(ctrl.Current()) + select { + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return ctx.Err() + case <-timer.C: + } + + if err := pollOnce(ctx); err != nil && onError != nil { + onError(err) + } + } +} + +func RegisterFocusHandler(d *Dispatcher, ctrl *IntervalController) { + d.Register("client.focus", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct { + Focused bool `json:"focused"` + } + if err := decodeParams(params, &p); err != nil { + return nil, err + } + ctrl.SetFocused(p.Focused) + return nil, nil + }) +} + +func pollIncremental[T any]( + ctx context.Context, + st *incrementalState, + fetch func(context.Context, int64) ([]T, error), + tsOf func(T) int64, + idOf func(T) string, + method string, + payloadOf func(T) any, + notify NotifyFunc, +) error { + startCursor := st.cursor + queryAfter := startCursor + if queryAfter > 0 { + // The API filter is exclusive, so query cursor-1 and dedupe by id at the boundary second. + queryAfter-- + } + + rows, err := fetch(ctx, queryAfter) + if err != nil { + return err + } + + maxCursor := startCursor + if len(rows) > 0 { + maxRow := slices.MaxFunc(rows, func(a, b T) int { + return cmp.Compare(tsOf(a), tsOf(b)) + }) + maxCursor = max(maxCursor, tsOf(maxRow)) + } + + nextBoundaryIDs := make(map[string]struct{}) + if maxCursor == startCursor { + for id := range st.boundaryIDs { + nextBoundaryIDs[id] = struct{}{} + } + } + for _, row := range rows { + if tsOf(row) == maxCursor { + nextBoundaryIDs[idOf(row)] = struct{}{} + } + } + + slices.SortStableFunc(rows, func(a, b T) int { + if ts := cmp.Compare(tsOf(a), tsOf(b)); ts != 0 { + return ts + } + return cmp.Compare(idOf(a), idOf(b)) + }) + + emitted := make(map[string]struct{}) + for _, row := range rows { + ts := tsOf(row) + id := idOf(row) + if ts < startCursor { + continue + } + if _, ok := emitted[id]; ok { + continue + } + if ts == startCursor { + if _, ok := st.boundaryIDs[id]; ok { + continue + } + } + + if err := notify(method, payloadOf(row)); err != nil { + return err + } + emitted[id] = struct{}{} + } + st.cursor = maxCursor + st.boundaryIDs = nextBoundaryIDs + return nil +} diff --git a/internal/adapters/rpcserver/incremental_test.go b/internal/adapters/rpcserver/incremental_test.go new file mode 100644 index 0000000..93cc8c7 --- /dev/null +++ b/internal/adapters/rpcserver/incremental_test.go @@ -0,0 +1,108 @@ +package rpcserver + +import ( + "context" + "errors" + "sync" + "testing" + "time" +) + +func TestIntervalController_Current(t *testing.T) { + fast := 5 * time.Second + idle := 30 * time.Second + ctrl := NewIntervalController(fast, idle) + if got := ctrl.Current(); got != idle { + t.Fatalf("Current() = %v, want default idle %v", got, idle) + } + + tests := []struct { + name string + focused bool + want time.Duration + }{ + {name: "fast when focused", focused: true, want: fast}, + {name: "idle after focus clears", focused: false, want: idle}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl.SetFocused(tt.focused) + if got := ctrl.Current(); got != tt.want { + t.Fatalf("Current() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIntervalController_CurrentConcurrentAccess(t *testing.T) { + ctrl := NewIntervalController(time.Millisecond, time.Second) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + <-start + for i := range 10_000 { + ctrl.SetFocused(i%2 == 0) + } + }() + + go func() { + defer wg.Done() + <-start + for range 10_000 { + _ = ctrl.Current() + } + }() + + close(start) + wg.Wait() +} + +func TestRunAdaptive_PollsReportsErrorsAndReturnsContextError(t *testing.T) { + ctrl := NewIntervalController(time.Millisecond, time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + pollErr := errors.New("poll failed") + var calls int + var gotErrs []error + + err := RunAdaptive(ctx, ctrl, func(err error) { + gotErrs = append(gotErrs, err) + }, func(ctx context.Context) error { + calls++ + if calls == 1 { + return pollErr + } + cancel() + return nil + }) + + if !errors.Is(err, context.Canceled) { + t.Fatalf("RunAdaptive() error = %v, want %v", err, context.Canceled) + } + if calls < 2 { + t.Fatalf("pollOnce calls = %d, want at least 2", calls) + } + if len(gotErrs) != 1 || !errors.Is(gotErrs[0], pollErr) { + t.Fatalf("onError calls = %v, want %v", gotErrs, pollErr) + } +} + +func TestRegisterFocusHandler(t *testing.T) { + fast := time.Millisecond + idle := time.Second + ctrl := NewIntervalController(fast, idle) + d := NewDispatcher() + RegisterFocusHandler(d, ctrl) + + got := d.Dispatch(context.Background(), []byte(`{"jsonrpc":"2.0","method":"client.focus","params":{"focused":true}}`)) + if got != nil { + t.Fatalf("Dispatch() = %s, want nil", got) + } + if current := ctrl.Current(); current != fast { + t.Fatalf("Current() = %v, want %v", current, fast) + } +} diff --git a/internal/adapters/rpcserver/jsonrpc.go b/internal/adapters/rpcserver/jsonrpc.go new file mode 100644 index 0000000..72733a1 --- /dev/null +++ b/internal/adapters/rpcserver/jsonrpc.go @@ -0,0 +1,154 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" +) + +const ( + ParseError = -32700 + InvalidRequest = -32600 + MethodNotFound = -32601 + InvalidParams = -32602 + InternalError = -32603 +) + +type Request struct { + JSONRPC string `json:"jsonrpc"` + ID *json.RawMessage `json:"id,omitempty"` + Method string `json:"method"` + Params json.RawMessage `json:"params,omitempty"` +} + +type RPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +func NewRPCError(code int, message string, data any) *RPCError { + return &RPCError{Code: code, Message: message, Data: data} +} + +func (e *RPCError) Error() string { + return fmt.Sprintf("json-rpc error %d: %s", e.Code, e.Message) +} + +type Response struct { + JSONRPC string `json:"jsonrpc"` + ID *json.RawMessage `json:"id,omitempty"` + Result any `json:"result,omitempty"` + Error *RPCError `json:"error,omitempty"` +} + +type HandlerFunc func(ctx context.Context, params json.RawMessage) (any, error) + +type Dispatcher struct { + handlers map[string]HandlerFunc + LogError func(error) +} + +func NewDispatcher() *Dispatcher { + return &Dispatcher{handlers: make(map[string]HandlerFunc)} +} + +func (d *Dispatcher) Register(method string, h HandlerFunc) { + d.handlers[method] = h +} + +func (d *Dispatcher) Dispatch(ctx context.Context, raw []byte) []byte { + var req Request + if err := json.Unmarshal(raw, &req); err != nil { + return marshalResponse(Response{ + JSONRPC: "2.0", + ID: nullID(), + Error: NewRPCError(ParseError, "parse error", err.Error()), + }) + } + + if req.ID == nil { + if h, ok := d.handlers[req.Method]; ok { + if _, err := h(ctx, req.Params); err != nil { + d.logError(err) + } + } + return nil + } + + if req.JSONRPC != "2.0" { + return marshalResponse(errorResponse(nullID(), NewRPCError(InvalidRequest, "invalid request", nil))) + } + + if req.Method == "" { + return marshalResponse(errorResponse(nullID(), NewRPCError(InvalidRequest, "invalid request", nil))) + } + + h, ok := d.handlers[req.Method] + if !ok { + return marshalResponse(errorResponse(req.ID, NewRPCError(MethodNotFound, "method not found", nil))) + } + + result, err := h(ctx, req.Params) + if err != nil { + var rpcErr *RPCError + if !errors.As(err, &rpcErr) { + d.logError(err) + rpcErr = NewRPCError(InternalError, "internal error", nil) + } + return marshalResponse(errorResponse(req.ID, rpcErr)) + } + + return marshalResponse(Response{ + JSONRPC: "2.0", + ID: req.ID, + Result: result, + }) +} + +type Notification struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +func NewNotification(method string, params any) ([]byte, error) { + return json.Marshal(Notification{ + JSONRPC: "2.0", + Method: method, + Params: params, + }) +} + +func errorResponse(id *json.RawMessage, rpcErr *RPCError) Response { + return Response{ + JSONRPC: "2.0", + ID: id, + Error: rpcErr, + } +} + +func nullID() *json.RawMessage { + raw := json.RawMessage("null") + return &raw +} + +func (d *Dispatcher) logError(err error) { + if d.LogError != nil { + d.LogError(err) + } +} + +func marshalResponse(resp Response) []byte { + data, err := json.Marshal(resp) + if err != nil { + fallback := Response{ + JSONRPC: "2.0", + ID: resp.ID, + Error: NewRPCError(InternalError, "internal error", err.Error()), + } + data, _ = json.Marshal(fallback) + } + return data +} diff --git a/internal/adapters/rpcserver/jsonrpc_test.go b/internal/adapters/rpcserver/jsonrpc_test.go new file mode 100644 index 0000000..5d9d3fa --- /dev/null +++ b/internal/adapters/rpcserver/jsonrpc_test.go @@ -0,0 +1,265 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "testing" +) + +func TestDispatcher_Dispatch(t *testing.T) { + tests := []struct { + name string + raw []byte + register func(*Dispatcher) + wantNil bool + wantID string + wantResult string + wantCode int + wantMsg string + wantData string + wantLog string + }{ + { + name: "valid request routes to handler and returns result", + raw: []byte(`{"jsonrpc":"2.0","id":1,"method":"ping","params":{"name":"Ada"}}`), + register: func(d *Dispatcher) { + d.Register("ping", func(ctx context.Context, params json.RawMessage) (any, error) { + return map[string]string{"status": "ok"}, nil + }) + }, + wantID: "1", + wantResult: `{"status":"ok"}`, + }, + { + name: "unknown notification returns nil", + raw: []byte(`{"jsonrpc":"2.0","method":"ping","params":{"name":"Ada"}}`), + wantNil: true, + }, + { + name: "unknown method returns method not found", + raw: []byte(`{"jsonrpc":"2.0","id":"abc","method":"missing"}`), + wantID: `"abc"`, + wantCode: MethodNotFound, + wantMsg: "method not found", + }, + { + name: "malformed JSON returns parse error", + raw: []byte(`{"jsonrpc":"2.0","id":1,"method":`), + wantID: "null", + wantCode: ParseError, + wantMsg: "parse error", + }, + { + name: "empty method returns invalid request", + raw: []byte(`{"jsonrpc":"2.0","id":1,"method":""}`), + wantID: "null", + wantCode: InvalidRequest, + wantMsg: "invalid request", + }, + { + name: "missing jsonrpc version returns invalid request", + raw: []byte(`{"id":1,"method":"ping"}`), + wantID: "null", + wantCode: InvalidRequest, + wantMsg: "invalid request", + }, + { + name: "wrong jsonrpc version returns invalid request", + raw: []byte(`{"jsonrpc":"1.0","id":1,"method":"ping"}`), + wantID: "null", + wantCode: InvalidRequest, + wantMsg: "invalid request", + }, + { + name: "plain handler error maps to internal error", + raw: []byte(`{"jsonrpc":"2.0","id":1,"method":"boom"}`), + register: func(d *Dispatcher) { + d.Register("boom", func(ctx context.Context, params json.RawMessage) (any, error) { + return nil, errors.New("database unavailable") + }) + }, + wantID: "1", + wantCode: InternalError, + wantMsg: "internal error", + wantLog: "database unavailable", + }, + { + name: "RPCError handler error passes through", + raw: []byte(`{"jsonrpc":"2.0","id":1,"method":"badParams"}`), + register: func(d *Dispatcher) { + d.Register("badParams", func(ctx context.Context, params json.RawMessage) (any, error) { + return nil, fmt.Errorf("checking params: %w", NewRPCError(InvalidParams, "bad params", map[string]string{"field": "email"})) + }) + }, + wantID: "1", + wantCode: InvalidParams, + wantMsg: "bad params", + wantData: `{"field":"email"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := NewDispatcher() + var loggedErr error + if tt.wantLog != "" { + d.LogError = func(err error) { + loggedErr = err + } + } + if tt.register != nil { + tt.register(d) + } + + got := d.Dispatch(context.Background(), tt.raw) + if tt.wantNil { + if got != nil { + t.Fatalf("Dispatch() = %s, want nil", got) + } + return + } + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + + var resp Response + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %q, want %q", resp.JSONRPC, "2.0") + } + + if tt.wantID != "" { + wantID := `"id":` + tt.wantID + if !strings.Contains(string(got), wantID) { + t.Errorf("response = %s, want %s", got, wantID) + } + } + + if tt.wantCode != 0 { + if resp.Error == nil { + t.Fatal("Error = nil, want RPC error") + } + if resp.Error.Code != tt.wantCode { + t.Errorf("Error.Code = %d, want %d", resp.Error.Code, tt.wantCode) + } + if resp.Error.Message != tt.wantMsg { + t.Errorf("Error.Message = %q, want %q", resp.Error.Message, tt.wantMsg) + } + if tt.wantData != "" { + gotData, err := json.Marshal(resp.Error.Data) + if err != nil { + t.Fatalf("marshal error data: %v", err) + } + if string(gotData) != tt.wantData { + t.Errorf("Error.Data = %s, want %s", gotData, tt.wantData) + } + } + if tt.wantLog != "" { + if loggedErr == nil { + t.Fatal("LogError was not called") + } + if !strings.Contains(loggedErr.Error(), tt.wantLog) { + t.Errorf("logged error = %q, want it to contain %q", loggedErr, tt.wantLog) + } + } + return + } + + if resp.Error != nil { + t.Fatalf("Error = %#v, want nil", resp.Error) + } + gotResult, err := json.Marshal(resp.Result) + if err != nil { + t.Fatalf("marshal result: %v", err) + } + if string(gotResult) != tt.wantResult { + t.Errorf("Result = %s, want %s", gotResult, tt.wantResult) + } + }) + } +} + +func TestDispatcher_Dispatch_NotificationRunsHandler(t *testing.T) { + d := NewDispatcher() + called := false + + d.Register("ping", func(ctx context.Context, params json.RawMessage) (any, error) { + called = true + return "ignored", nil + }) + + got := d.Dispatch(context.Background(), []byte(`{"jsonrpc":"2.0","method":"ping","params":{"name":"Ada"}}`)) + if got != nil { + t.Fatalf("Dispatch() = %s, want nil", got) + } + if !called { + t.Fatal("handler was not called") + } +} + +func TestDispatcher_Dispatch_ParseErrorSerializesNullID(t *testing.T) { + tests := []struct { + name string + raw []byte + }{ + {name: "malformed object", raw: []byte(`{"jsonrpc":"2.0","id":1,"method":`)}, + {name: "not json", raw: []byte(`nope`)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewDispatcher().Dispatch(context.Background(), tt.raw) + if !strings.Contains(string(got), `"id":null`) { + t.Fatalf("Dispatch() = %s, want serialized null id", got) + } + }) + } +} + +func TestDispatcher_Dispatch_ForwardsRawParams(t *testing.T) { + d := NewDispatcher() + wantParams := `{"email":"ada@example.com","limit":5}` + called := false + + d.Register("inspect", func(ctx context.Context, params json.RawMessage) (any, error) { + called = true + if string(params) != wantParams { + t.Errorf("params = %s, want %s", params, wantParams) + } + return "ok", nil + }) + + got := d.Dispatch(context.Background(), []byte(`{"jsonrpc":"2.0","id":1,"method":"inspect","params":`+wantParams+`}`)) + if got == nil { + t.Fatal("Dispatch() = nil, want response") + } + if !called { + t.Fatal("handler was not called") + } +} + +func TestNewNotification(t *testing.T) { + got, err := NewNotification("progress", map[string]int{"percent": 50}) + if err != nil { + t.Fatalf("NewNotification() error = %v", err) + } + + var notif Notification + if err := json.Unmarshal(got, ¬if); err != nil { + t.Fatalf("unmarshal notification: %v", err) + } + if notif.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %q, want %q", notif.JSONRPC, "2.0") + } + if notif.Method != "progress" { + t.Errorf("Method = %q, want %q", notif.Method, "progress") + } + if string(got) != `{"jsonrpc":"2.0","method":"progress","params":{"percent":50}}` { + t.Errorf("NewNotification() = %s", got) + } +} diff --git a/internal/adapters/rpcserver/poller_contacts.go b/internal/adapters/rpcserver/poller_contacts.go new file mode 100644 index 0000000..a37de05 --- /dev/null +++ b/internal/adapters/rpcserver/poller_contacts.go @@ -0,0 +1,192 @@ +package rpcserver + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +const ( + contactPollLimit = 100 + maxContactPollPages = 500 +) + +type ContactPoller struct { + client ports.ContactClient + grantID string + seen map[string]string + seeded bool + notify NotifyFunc +} + +type contactUpdatedPayload struct { + ID string `json:"id"` + GivenName string `json:"given_name"` + Surname string `json:"surname"` + Emails []domain.ContactEmail `json:"emails"` + UpdatedAt int64 `json:"updated_at"` +} + +func NewContactPoller(client ports.ContactClient, grantID string, notify NotifyFunc) *ContactPoller { + return &ContactPoller{ + client: client, + grantID: grantID, + seen: make(map[string]string), + notify: notify, + } +} + +func (p *ContactPoller) PollOnce(ctx context.Context) error { + var contacts []domain.Contact + pageToken := "" + for page := range maxContactPollPages { + resp, err := p.client.GetContactsWithCursor(ctx, p.grantID, &domain.ContactQueryParams{ + Limit: contactPollLimit, + PageToken: pageToken, + }) + if err != nil { + return err + } + if resp == nil { + return errors.New("contact poll response is nil") + } + contacts = append(contacts, resp.Data...) + if resp.Pagination.NextCursor == "" || !resp.Pagination.HasMore { + break + } + if page == maxContactPollPages-1 { + return fmt.Errorf("contact poll truncated at %d pages; not committing snapshot", maxContactPollPages) + } + pageToken = resp.Pagination.NextCursor + } + // ponytail: cap full refetches at 500 pages (~50k contacts); truncation errors still backstop pathological pagination bugs. + + nextSeen := make(map[string]string, len(contacts)) + var changed []domain.Contact + for _, contact := range contacts { + fingerprint := contactFingerprint(contact) + nextSeen[contact.ID] = fingerprint + if !p.seeded { + continue + } + if lastSeen, ok := p.seen[contact.ID]; !ok || lastSeen != fingerprint { + changed = append(changed, contact) + } + } + + for _, contact := range changed { + if err := p.notify("contact.updated", contactUpdatedPayload{ + ID: contact.ID, + GivenName: contact.GivenName, + Surname: contact.Surname, + Emails: contact.Emails, + UpdatedAt: contact.UpdatedAt, + }); err != nil { + return err + } + } + + if p.seeded { + for id := range p.seen { + if _, ok := nextSeen[id]; ok { + continue + } + if err := p.notify("contact.deleted", map[string]string{"id": id}); err != nil { + return err + } + } + } + + p.seen = nextSeen + p.seeded = true + return nil +} + +func (p *ContactPoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { + return runTicker(ctx, interval, onError, p.PollOnce) +} + +func contactFingerprint(c domain.Contact) string { + records := []string{} + appendRecord := func(parts ...string) { + encoded := make([]string, 0, len(parts)*2) + for _, part := range parts { + encoded = append(encoded, strconv.Itoa(len(part)), part) + } + records = append(records, strings.Join(encoded, ":")) + } + appendSorted := func(label string, values []string) { + sort.Strings(values) + for _, value := range values { + appendRecord(label, value) + } + } + + appendRecord("given_name", c.GivenName) + appendRecord("middle_name", c.MiddleName) + appendRecord("surname", c.Surname) + appendRecord("suffix", c.Suffix) + appendRecord("nickname", c.Nickname) + appendRecord("birthday", c.Birthday) + appendRecord("company_name", c.CompanyName) + appendRecord("job_title", c.JobTitle) + appendRecord("manager_name", c.ManagerName) + appendRecord("notes", c.Notes) + appendRecord("picture_url", c.PictureURL) + appendRecord("picture", c.Picture) + appendRecord("source", c.Source) + + emails := make([]string, 0, len(c.Emails)) + for _, email := range c.Emails { + emails = append(emails, strings.Join([]string{email.Email, email.Type}, "\x00")) + } + appendSorted("emails", emails) + + phones := make([]string, 0, len(c.PhoneNumbers)) + for _, phone := range c.PhoneNumbers { + phones = append(phones, strings.Join([]string{phone.Number, phone.Type}, "\x00")) + } + appendSorted("phone_numbers", phones) + + webPages := make([]string, 0, len(c.WebPages)) + for _, webPage := range c.WebPages { + webPages = append(webPages, strings.Join([]string{webPage.URL, webPage.Type}, "\x00")) + } + appendSorted("web_pages", webPages) + + imAddresses := make([]string, 0, len(c.IMAddresses)) + for _, im := range c.IMAddresses { + imAddresses = append(imAddresses, strings.Join([]string{im.IMAddress, im.Type}, "\x00")) + } + appendSorted("im_addresses", imAddresses) + + addresses := make([]string, 0, len(c.PhysicalAddresses)) + for _, address := range c.PhysicalAddresses { + addresses = append(addresses, strings.Join([]string{ + address.Type, + address.StreetAddress, + address.City, + address.State, + address.PostalCode, + address.Country, + }, "\x00")) + } + appendSorted("physical_addresses", addresses) + + groups := make([]string, 0, len(c.Groups)) + for _, group := range c.Groups { + groups = append(groups, group.ID) + } + appendSorted("groups", groups) + + sum := sha256.Sum256([]byte(strings.Join(records, "\x00"))) + return fmt.Sprintf("%x", sum) +} diff --git a/internal/adapters/rpcserver/poller_contacts_test.go b/internal/adapters/rpcserver/poller_contacts_test.go new file mode 100644 index 0000000..1121c2f --- /dev/null +++ b/internal/adapters/rpcserver/poller_contacts_test.go @@ -0,0 +1,475 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" +) + +type contactNotifyCall struct { + method string + params any +} + +type contactPollScript map[string]domain.ContactListResponse + +func TestContactPoller_PollOnce_FirstPollSeedsWithoutEmitting(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1), pollContact("contact-2", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + if !poller.seeded { + t.Fatal("seeded = false, want true") + } + wantSeen := map[string]string{ + "contact-1": contactFingerprint(pollContact("contact-1", 1)), + "contact-2": contactFingerprint(pollContact("contact-2", 1)), + } + if !reflect.DeepEqual(poller.seen, wantSeen) { + t.Fatalf("seen = %#v, want both contacts", poller.seen) + } +} + +func TestContactPoller_PollOnce_SecondPollEmitsChangedContact(t *testing.T) { + changed := pollContact("contact-1", 2) + changed.GivenName = "Grace" + + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1), pollContact("contact-2", 1)}}}, + {"": {Data: []domain.Contact{changed, pollContact("contact-2", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + assertContactNotifyIDs(t, calls, []string{"contact-1"}) + gotPayload := calls[0].params.(contactUpdatedPayload) + wantPayload := contactUpdatedPayload{ + ID: "contact-1", + GivenName: "Grace", + Surname: "Lovelace", + Emails: []domain.ContactEmail{{Email: "ada@example.com", Type: "work"}}, + UpdatedAt: 2, + } + if !reflect.DeepEqual(gotPayload, wantPayload) { + t.Fatalf("payload = %#v, want %#v", gotPayload, wantPayload) + } + + payloadJSON, err := json.Marshal(gotPayload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + wantJSON := `{"id":"contact-1","given_name":"Grace","surname":"Lovelace","emails":[{"email":"ada@example.com","type":"work"}],"updated_at":2}` + if string(payloadJSON) != wantJSON { + t.Fatalf("payload JSON = %s, want %s", payloadJSON, wantJSON) + } +} + +func TestContactPoller_PollOnce_SecondPollEmitsNewContact(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + {"": {Data: []domain.Contact{pollContact("contact-1", 1), pollContact("contact-2", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + assertContactNotifyIDs(t, calls, []string{"contact-2"}) +} + +func TestContactPoller_PollOnce_UnchangedContactsEmitNothing(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } +} + +func TestContactPoller_PollOnce_ContentFingerprintChangeDetection(t *testing.T) { + contentChanged := pollContact("contact-1", 1) + contentChanged.CompanyName = "Analytical Engines Ltd." + + timestampChanged := pollContact("contact-1", 2) + + tests := []struct { + name string + secondPoll domain.Contact + wantIDs []string + }{ + { + name: "content changes with unchanged updated_at emits update", + secondPoll: contentChanged, + wantIDs: []string{"contact-1"}, + }, + { + name: "updated_at changes with identical content emits nothing", + secondPoll: timestampChanged, + wantIDs: nil, + }, + { + name: "same content and same updated_at emits nothing", + secondPoll: pollContact("contact-1", 1), + wantIDs: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + {"": {Data: []domain.Contact{tt.secondPoll}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + assertContactNotifyIDs(t, calls, tt.wantIDs) + }) + } +} + +func TestContactPoller_PollOnce_SecondPollEmitsDeletedContact(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1), pollContact("contact-2", 1)}}}, + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + if len(calls) != 1 { + t.Fatalf("notify calls = %#v, want exactly one deletion", calls) + } + assertContactDeletedCall(t, calls[0], "contact-2") +} + +func TestContactPoller_PollOnce_FirstPollDoesNotEmitDeletes(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1)}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + poller.seen = map[string]string{"contact-2": contactFingerprint(pollContact("contact-2", 1))} + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + if !reflect.DeepEqual(poller.seen, map[string]string{"contact-1": contactFingerprint(pollContact("contact-1", 1))}) { + t.Fatalf("seen = %#v, want seeded snapshot only", poller.seen) + } +} + +func TestContactPoller_PollOnce_EmitsChangedAndDeletedContact(t *testing.T) { + changed := pollContact("contact-1", 2) + changed.JobTitle = "Mathematician" + + client := scriptedFakeContactClient(t, []contactPollScript{ + {"": {Data: []domain.Contact{pollContact("contact-1", 1), pollContact("contact-2", 1)}}}, + {"": {Data: []domain.Contact{changed}}}, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + if len(calls) != 2 { + t.Fatalf("notify calls = %#v, want update and deletion", calls) + } + assertContactNotifyIDs(t, calls[:1], []string{"contact-1"}) + assertContactDeletedCall(t, calls[1], "contact-2") +} + +func TestContactPoller_PollOnce_DrainsPages(t *testing.T) { + client := scriptedFakeContactClient(t, []contactPollScript{ + { + "": { + Data: []domain.Contact{pollContact("contact-1", 1)}, + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }, + "page-2": {Data: []domain.Contact{pollContact("contact-2", 1)}}, + }, + }, nil) + + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + assertContactQueries(t, client.contactParams, []string{"", "page-2"}) + if !reflect.DeepEqual(client.contactGrantIDs, []string{"grant-123", "grant-123"}) { + t.Fatalf("grant IDs = %#v, want grant-123 for every page", client.contactGrantIDs) + } + wantSeen := map[string]string{ + "contact-1": contactFingerprint(pollContact("contact-1", 1)), + "contact-2": contactFingerprint(pollContact("contact-2", 1)), + } + if !reflect.DeepEqual(poller.seen, wantSeen) { + t.Fatalf("seen = %#v, want both pages", poller.seen) + } +} + +func TestContactPoller_PollOnce_ReturnsClientError(t *testing.T) { + clientErr := errors.New("api unavailable") + client := scriptedFakeContactClient(t, []contactPollScript{ + { + "": { + Data: []domain.Contact{pollContact("contact-1", 1)}, + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }, + }, + }, map[string]error{"page-2": clientErr}) + + called := false + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + called = true + return nil + }) + + err := poller.PollOnce(context.Background()) + if !errors.Is(err, clientErr) { + t.Fatalf("PollOnce() error = %v, want %v", err, clientErr) + } + if called { + t.Fatal("notify was called on client error") + } + if poller.seeded { + t.Fatal("seeded = true, want false") + } +} + +func TestContactPoller_PollOnce_ReturnsErrorWithoutCommitWhenPageCapTruncates(t *testing.T) { + pages := make(contactPollScript, maxContactPollPages) + pageToken := "" + for i := range maxContactPollPages { + nextCursor := fmt.Sprintf("page-%02d", i+1) + pages[pageToken] = domain.ContactListResponse{ + Data: []domain.Contact{pollContact(fmt.Sprintf("contact-%02d", i+1), int64(i+1))}, + Pagination: domain.Pagination{NextCursor: nextCursor, HasMore: true}, + } + pageToken = nextCursor + } + + client := scriptedFakeContactClient(t, []contactPollScript{pages}, nil) + var calls []contactNotifyCall + poller := NewContactPoller(client, "grant-123", func(method string, params any) error { + calls = append(calls, contactNotifyCall{method: method, params: params}) + return nil + }) + poller.seeded = true + poller.seen = map[string]string{"keep": "fingerprint"} + + err := poller.PollOnce(context.Background()) + if err == nil { + t.Fatal("PollOnce() error = nil, want truncation error") + } + if err.Error() != "contact poll truncated at 500 pages; not committing snapshot" { + t.Fatalf("PollOnce() error = %q, want truncation error", err) + } + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + if !poller.seeded { + t.Fatal("seeded = false, want unchanged true") + } + if !reflect.DeepEqual(poller.seen, map[string]string{"keep": "fingerprint"}) { + t.Fatalf("seen = %#v, want unchanged", poller.seen) + } +} + +func TestContactPoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + poller := NewContactPoller(&fakeContactClient{}, "grant-123", func(method string, params any) error { + t.Fatal("notify should not be called") + return nil + }) + + if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { + t.Fatalf("Run() error = %v, want %v", err, context.Canceled) + } +} + +func scriptedFakeContactClient(t *testing.T, polls []contactPollScript, errs map[string]error) *fakeContactClient { + t.Helper() + + poll := 0 + return &fakeContactClient{ + getContactsWithCursor: func(ctx context.Context, grantID string, params *domain.ContactQueryParams) (*domain.ContactListResponse, error) { + pageToken := "" + if params != nil { + pageToken = params.PageToken + } + if err := errs[pageToken]; err != nil { + return nil, err + } + if poll >= len(polls) { + t.Fatalf("unexpected poll %d page %q", poll+1, pageToken) + } + resp, ok := polls[poll][pageToken] + if !ok { + t.Fatalf("unexpected poll %d page %q", poll+1, pageToken) + } + if resp.Pagination.NextCursor == "" || !resp.Pagination.HasMore { + poll++ + } + return &resp, nil + }, + } +} + +func pollContact(id string, updatedAt int64) domain.Contact { + return domain.Contact{ + ID: id, + GivenName: "Ada", + Surname: "Lovelace", + Emails: []domain.ContactEmail{{Email: "ada@example.com", Type: "work"}}, + UpdatedAt: updatedAt, + } +} + +func assertContactNotifyIDs(t *testing.T, calls []contactNotifyCall, want []string) { + t.Helper() + + var got []string + for i, call := range calls { + if call.method != "contact.updated" { + t.Fatalf("notify call %d method = %q, want contact.updated", i, call.method) + } + payload, ok := call.params.(contactUpdatedPayload) + if !ok { + t.Fatalf("notify call %d payload type = %T, want contactUpdatedPayload", i, call.params) + } + got = append(got, payload.ID) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("notify IDs = %#v, want %#v", got, want) + } +} + +func assertContactDeletedCall(t *testing.T, call contactNotifyCall, wantID string) { + t.Helper() + + if call.method != "contact.deleted" { + t.Fatalf("notify method = %q, want contact.deleted", call.method) + } + payload, ok := call.params.(map[string]string) + if !ok { + t.Fatalf("notify payload type = %T, want map[string]string", call.params) + } + if !reflect.DeepEqual(payload, map[string]string{"id": wantID}) { + t.Fatalf("notify payload = %#v, want id %q", payload, wantID) + } +} + +func assertContactQueries(t *testing.T, got []domain.ContactQueryParams, wantPageTokens []string) { + t.Helper() + + if len(got) != len(wantPageTokens) { + t.Fatalf("query count = %d, want %d: %#v", len(got), len(wantPageTokens), got) + } + for i, pageToken := range wantPageTokens { + if got[i].Limit != contactPollLimit || got[i].PageToken != pageToken { + t.Fatalf("query %d = %+v, want limit %d page_token %q", i, got[i], contactPollLimit, pageToken) + } + } +} diff --git a/internal/adapters/rpcserver/poller_events.go b/internal/adapters/rpcserver/poller_events.go new file mode 100644 index 0000000..38a5a8f --- /dev/null +++ b/internal/adapters/rpcserver/poller_events.go @@ -0,0 +1,90 @@ +package rpcserver + +import ( + "context" + "errors" + "fmt" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +const ( + eventPollLimit = 50 + maxEventPollPages = 20 +) + +type EventPoller struct { + incrementalState + client ports.CalendarClient + grantID string + calendarID string + notify NotifyFunc +} + +type eventUpdatedPayload struct { + ID string `json:"id"` + CalendarID string `json:"calendar_id"` + Title string `json:"title"` + When domain.EventWhen `json:"when"` + Status string `json:"status"` + UpdatedAt int64 `json:"updated_at"` +} + +func NewEventPoller(client ports.CalendarClient, grantID, calendarID string, since int64, notify NotifyFunc) *EventPoller { + if calendarID == "" { + calendarID = "primary" + } + return &EventPoller{ + incrementalState: incrementalState{cursor: since}, + client: client, + grantID: grantID, + calendarID: calendarID, + notify: notify, + } +} + +func (p *EventPoller) PollOnce(ctx context.Context) error { + return pollIncremental(ctx, &p.incrementalState, p.fetch, func(event domain.Event) int64 { + return event.UpdatedAt.Unix() + }, func(event domain.Event) string { + return event.ID + }, "event.updated", func(event domain.Event) any { + return eventUpdatedPayload{ + ID: event.ID, + CalendarID: p.calendarID, + Title: event.Title, + When: event.When, + Status: event.Status, + UpdatedAt: event.UpdatedAt.Unix(), + } + }, p.notify) +} + +func (p *EventPoller) fetch(ctx context.Context, queryAfter int64) ([]domain.Event, error) { + var events []domain.Event + pageToken := "" + for page := range maxEventPollPages { + resp, err := p.client.GetEventsWithCursor(ctx, p.grantID, p.calendarID, &domain.EventQueryParams{ + Limit: eventPollLimit, + PageToken: pageToken, + UpdatedAfter: queryAfter, + }) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("event poll response is nil") + } + events = append(events, resp.Data...) + if resp.Pagination.NextCursor == "" || !resp.Pagination.HasMore { + break + } + if page == maxEventPollPages-1 { + return nil, fmt.Errorf("event poll truncated at %d pages; not advancing cursor", maxEventPollPages) + } + pageToken = resp.Pagination.NextCursor + } + // ponytail: cap polling bursts at 20 pages; webhooks are the real fix for larger calendar spikes. + return events, nil +} diff --git a/internal/adapters/rpcserver/poller_events_test.go b/internal/adapters/rpcserver/poller_events_test.go new file mode 100644 index 0000000..e4d6c2b --- /dev/null +++ b/internal/adapters/rpcserver/poller_events_test.go @@ -0,0 +1,322 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" +) + +func TestEventPoller_PollOnce_DrainsPagesEmitsAllNewEvents(t *testing.T) { + client, queries := fakeEventPollPages(map[string][]domain.EventListResponse{ + "": {{ + Data: pollEvents(175, 126), + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }}, + "page-2": {{ + Data: pollEvents(125, 106), + Pagination: domain.Pagination{NextCursor: "page-3", HasMore: true}, + }}, + "page-3": {{ + Data: pollEvents(105, 101), + }}, + }, nil) + + var calls []notifyCall + poller := NewEventPoller(client, "grant-123", "cal-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + var wantIDs []string + for ts := int64(101); ts <= 175; ts++ { + wantIDs = append(wantIDs, fmt.Sprintf("event-%03d", ts)) + } + assertEventNotifyIDs(t, calls, wantIDs) + assertUniqueEventNotifyIDs(t, calls) + assertEventQueries(t, *queries, []eventWantQuery{ + {updatedAfter: 99, pageToken: "", calendarID: "cal-123"}, + {updatedAfter: 99, pageToken: "page-2", calendarID: "cal-123"}, + {updatedAfter: 99, pageToken: "page-3", calendarID: "cal-123"}, + }) + if poller.cursor != 175 { + t.Fatalf("cursor = %d, want 175", poller.cursor) + } +} + +func TestEventPoller_PollOnce_EmitsSameSecondBoundaryEventOnce(t *testing.T) { + client, queries := fakeEventPollPages(map[string][]domain.EventListResponse{ + "": { + {Data: []domain.Event{pollEvent("boundary-a", 105)}}, + {Data: []domain.Event{ + pollEvent("boundary-a", 105), + pollEvent("boundary-b", 105), + pollEvent("newer-c", 106), + }}, + {Data: []domain.Event{ + pollEvent("boundary-a", 105), + pollEvent("boundary-b", 105), + pollEvent("newer-c", 106), + }}, + }, + }, nil) + + var calls []notifyCall + poller := NewEventPoller(client, "grant-123", "", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + for i := 0; i < 3; i++ { + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() #%d error = %v", i+1, err) + } + } + + assertEventNotifyIDs(t, calls, []string{"boundary-a", "boundary-b", "newer-c"}) + assertUniqueEventNotifyIDs(t, calls) + assertEventQueries(t, *queries, []eventWantQuery{ + {updatedAfter: 99, pageToken: "", calendarID: "primary"}, + {updatedAfter: 104, pageToken: "", calendarID: "primary"}, + {updatedAfter: 105, pageToken: "", calendarID: "primary"}, + }) +} + +func TestEventPoller_PollOnce_NoNewSecondPollEmitsNothing(t *testing.T) { + client, queries := fakeEventPollPages(map[string][]domain.EventListResponse{ + "": { + {Data: []domain.Event{{ + ID: "new-1", + CalendarID: "cal-123", + Title: "Sync", + When: domain.EventWhen{StartTime: 1010, EndTime: 1020, Object: "timespan"}, + Status: "confirmed", + UpdatedAt: time.Unix(101, 0), + }}}, + {}, + }, + }, nil) + + var calls []notifyCall + poller := NewEventPoller(client, "grant-123", "cal-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + assertEventNotifyIDs(t, calls, []string{"new-1"}) + gotPayload, ok := calls[0].params.(eventUpdatedPayload) + if !ok { + t.Fatalf("payload type = %T, want eventUpdatedPayload", calls[0].params) + } + wantPayload := eventUpdatedPayload{ + ID: "new-1", + CalendarID: "cal-123", + Title: "Sync", + When: domain.EventWhen{StartTime: 1010, EndTime: 1020, Object: "timespan"}, + Status: "confirmed", + UpdatedAt: 101, + } + if !reflect.DeepEqual(gotPayload, wantPayload) { + t.Fatalf("payload = %#v, want %#v", gotPayload, wantPayload) + } + + payloadJSON, err := json.Marshal(gotPayload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + wantJSON := `{"id":"new-1","calendar_id":"cal-123","title":"Sync","when":{"start_time":1010,"end_time":1020,"object":"timespan"},"status":"confirmed","updated_at":101}` + if string(payloadJSON) != wantJSON { + t.Fatalf("payload JSON = %s, want %s", payloadJSON, wantJSON) + } + assertEventQueries(t, *queries, []eventWantQuery{ + {updatedAfter: 99, pageToken: "", calendarID: "cal-123"}, + {updatedAfter: 100, pageToken: "", calendarID: "cal-123"}, + }) +} + +func TestEventPoller_PollOnce_ReturnsClientErrorFromLaterPage(t *testing.T) { + clientErr := errors.New("api unavailable") + client, _ := fakeEventPollPages(map[string][]domain.EventListResponse{ + "": {{ + Data: []domain.Event{pollEvent("new", 101)}, + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }}, + }, map[string]error{"page-2": clientErr}) + called := false + poller := NewEventPoller(client, "grant-123", "cal-123", 100, func(method string, params any) error { + called = true + return nil + }) + + err := poller.PollOnce(context.Background()) + if !errors.Is(err, clientErr) { + t.Fatalf("PollOnce() error = %v, want %v", err, clientErr) + } + if called { + t.Fatal("notify was called on client error") + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } +} + +func TestEventPoller_PollOnce_ReturnsErrorWhenPageDrainTruncates(t *testing.T) { + tests := []struct { + name string + }{ + {name: "more pages after cap"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pages := make(map[string][]domain.EventListResponse, maxEventPollPages) + pageToken := "" + for page := range maxEventPollPages { + nextToken := fmt.Sprintf("page-%d", page+1) + pages[pageToken] = []domain.EventListResponse{{ + Data: []domain.Event{pollEvent(fmt.Sprintf("event-%03d", page), int64(101+page))}, + Pagination: domain.Pagination{NextCursor: nextToken, HasMore: true}, + }} + pageToken = nextToken + } + client, _ := fakeEventPollPages(pages, nil) + + var calls []notifyCall + poller := NewEventPoller(client, "grant-123", "cal-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + err := poller.PollOnce(context.Background()) + wantErr := fmt.Sprintf("event poll truncated at %d pages; not advancing cursor", maxEventPollPages) + if err == nil || err.Error() != wantErr { + t.Fatalf("PollOnce() error = %v, want %q", err, wantErr) + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + }) + } +} + +func fakeEventPollPages(pages map[string][]domain.EventListResponse, errs map[string]error) (*fakeCalendarClient, *[]eventWantQuery) { + var queries []eventWantQuery + return &fakeCalendarClient{ + getEventsWithCursor: func(ctx context.Context, grantID, calendarID string, params *domain.EventQueryParams) (*domain.EventListResponse, error) { + pageToken := "" + if params != nil { + pageToken = params.PageToken + queries = append(queries, eventWantQuery{ + updatedAfter: params.UpdatedAfter, + pageToken: params.PageToken, + calendarID: calendarID, + }) + } + if err := errs[pageToken]; err != nil { + return nil, err + } + if len(pages[pageToken]) == 0 { + return &domain.EventListResponse{}, nil + } + resp := pages[pageToken][0] + pages[pageToken] = pages[pageToken][1:] + if params != nil { + resp.Data = filterEventsAfter(resp.Data, params.UpdatedAfter) + } + return &resp, nil + }, + }, &queries +} + +func pollEvents(newest, oldest int64) []domain.Event { + var events []domain.Event + for ts := newest; ts >= oldest; ts-- { + events = append(events, pollEvent(fmt.Sprintf("event-%03d", ts), ts)) + } + return events +} + +func pollEvent(id string, unix int64) domain.Event { + return domain.Event{ + ID: id, + Title: id, + UpdatedAt: time.Unix(unix, 0), + } +} + +func filterEventsAfter(events []domain.Event, updatedAfter int64) []domain.Event { + var filtered []domain.Event + for _, event := range events { + if event.UpdatedAt.Unix() > updatedAfter { + filtered = append(filtered, event) + } + } + return filtered +} + +func assertEventNotifyIDs(t *testing.T, calls []notifyCall, want []string) { + t.Helper() + var got []string + for i, call := range calls { + if call.method != "event.updated" { + t.Fatalf("notify call %d method = %q, want event.updated", i, call.method) + } + payload, ok := call.params.(eventUpdatedPayload) + if !ok { + t.Fatalf("notify call %d payload type = %T, want eventUpdatedPayload", i, call.params) + } + got = append(got, payload.ID) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("notify IDs = %#v, want %#v", got, want) + } +} + +func assertUniqueEventNotifyIDs(t *testing.T, calls []notifyCall) { + t.Helper() + seen := make(map[string]struct{}) + for _, call := range calls { + payload := call.params.(eventUpdatedPayload) + if _, ok := seen[payload.ID]; ok { + t.Fatalf("duplicate notify ID %q", payload.ID) + } + seen[payload.ID] = struct{}{} + } +} + +type eventWantQuery struct { + updatedAfter int64 + pageToken string + calendarID string +} + +func assertEventQueries(t *testing.T, got []eventWantQuery, want []eventWantQuery) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("query count = %d, want %d: %#v", len(got), len(want), got) + } + for i := range want { + if got[i].updatedAfter != want[i].updatedAfter || got[i].pageToken != want[i].pageToken || got[i].calendarID != want[i].calendarID { + t.Fatalf("query %d = %+v, want %+v", i, got[i], want[i]) + } + } +} diff --git a/internal/adapters/rpcserver/poller_messages.go b/internal/adapters/rpcserver/poller_messages.go new file mode 100644 index 0000000..8e54ab4 --- /dev/null +++ b/internal/adapters/rpcserver/poller_messages.go @@ -0,0 +1,100 @@ +package rpcserver + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +const ( + messagePollLimit = 50 + maxMessagePollPages = 20 +) + +// NotifyFunc emits a server->client notification. +type NotifyFunc func(method string, params any) error + +type MessagePoller struct { + incrementalState + client ports.MessageClient + grantID string + notify NotifyFunc +} + +type messageReceivedPayload struct { + ID string `json:"id"` + GrantID string `json:"grant_id"` + Subject string `json:"subject"` + Snippet string `json:"snippet"` + From []domain.EmailParticipant `json:"from"` + Date int64 `json:"date"` + Unread bool `json:"unread"` + Folders []string `json:"folders"` +} + +// NewMessagePoller polls messages newer than since. +func NewMessagePoller(client ports.MessageClient, grantID string, since int64, notify NotifyFunc) *MessagePoller { + return &MessagePoller{ + incrementalState: incrementalState{cursor: since}, + client: client, + grantID: grantID, + notify: notify, + } +} + +// PollOnce runs one polling cycle. +func (p *MessagePoller) PollOnce(ctx context.Context) error { + return pollIncremental(ctx, &p.incrementalState, p.fetch, func(msg domain.Message) int64 { + return msg.Date.Unix() + }, func(msg domain.Message) string { + return msg.ID + }, "message.received", func(msg domain.Message) any { + return messageReceivedPayload{ + ID: msg.ID, + GrantID: msg.GrantID, + Subject: msg.Subject, + Snippet: msg.Snippet, + From: msg.From, + Date: msg.Date.Unix(), + Unread: msg.Unread, + Folders: msg.Folders, + } + }, p.notify) +} + +func (p *MessagePoller) fetch(ctx context.Context, queryAfter int64) ([]domain.Message, error) { + var messages []domain.Message + pageToken := "" + for page := range maxMessagePollPages { + resp, err := p.client.GetMessagesWithCursor(ctx, p.grantID, &domain.MessageQueryParams{ + Limit: messagePollLimit, + PageToken: pageToken, + ReceivedAfter: queryAfter, + }) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("message poll response is nil") + } + messages = append(messages, resp.Data...) + if resp.Pagination.NextCursor == "" || !resp.Pagination.HasMore { + break + } + if page == maxMessagePollPages-1 { + return nil, fmt.Errorf("message poll truncated at %d pages; not advancing cursor", maxMessagePollPages) + } + pageToken = resp.Pagination.NextCursor + } + // ponytail: cap polling bursts at 20 pages; webhooks are the real fix for larger inbox spikes. + return messages, nil +} + +// Run polls until ctx is cancelled. +func (p *MessagePoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { + return runTicker(ctx, interval, onError, p.PollOnce) +} diff --git a/internal/adapters/rpcserver/poller_messages_test.go b/internal/adapters/rpcserver/poller_messages_test.go new file mode 100644 index 0000000..4187f68 --- /dev/null +++ b/internal/adapters/rpcserver/poller_messages_test.go @@ -0,0 +1,378 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +type fakePollClient struct { + ports.MessageClient + pages map[string][]domain.MessageListResponse + errs map[string]error + grantIDs []string + params []domain.MessageQueryParams +} + +func (f *fakePollClient) GetMessagesWithCursor(ctx context.Context, grantID string, params *domain.MessageQueryParams) (*domain.MessageListResponse, error) { + f.grantIDs = append(f.grantIDs, grantID) + pageToken := "" + if params != nil { + pageToken = params.PageToken + f.params = append(f.params, *params) + } + if err := f.errs[pageToken]; err != nil { + return nil, err + } + if len(f.pages[pageToken]) == 0 { + return &domain.MessageListResponse{}, nil + } + resp := f.pages[pageToken][0] + f.pages[pageToken] = f.pages[pageToken][1:] + if params != nil { + resp.Data = filterMessagesAfter(resp.Data, params.ReceivedAfter) + } + return &resp, nil +} + +type notifyCall struct { + method string + params any +} + +func TestMessagePoller_PollOnce_DrainsPagesEmitsAllNewMessages(t *testing.T) { + client := &fakePollClient{ + pages: map[string][]domain.MessageListResponse{ + "": {{ + Data: pollMessages(175, 126), + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }}, + "page-2": {{ + Data: pollMessages(125, 106), + Pagination: domain.Pagination{NextCursor: "page-3", HasMore: true}, + }}, + "page-3": {{ + Data: pollMessages(105, 101), + }}, + }, + } + + var calls []notifyCall + poller := NewMessagePoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + var wantIDs []string + for ts := int64(101); ts <= 175; ts++ { + wantIDs = append(wantIDs, fmt.Sprintf("msg-%03d", ts)) + } + assertNotifyIDs(t, calls, wantIDs) + assertUniqueNotifyIDs(t, calls) + assertQueries(t, client.params, []wantQuery{ + {receivedAfter: 99, pageToken: ""}, + {receivedAfter: 99, pageToken: "page-2"}, + {receivedAfter: 99, pageToken: "page-3"}, + }) + if !reflect.DeepEqual(client.grantIDs, []string{"grant-123", "grant-123", "grant-123"}) { + t.Fatalf("grant IDs = %#v, want grant-123 for every page", client.grantIDs) + } + if poller.cursor != 175 { + t.Fatalf("cursor = %d, want 175", poller.cursor) + } +} + +func TestMessagePoller_PollOnce_EmitsSameSecondBoundaryMessageOnce(t *testing.T) { + base := int64(100) + client := &fakePollClient{ + pages: map[string][]domain.MessageListResponse{ + "": { + {Data: []domain.Message{pollMessage("boundary-a", 105)}}, + {Data: []domain.Message{ + pollMessage("boundary-a", 105), + pollMessage("boundary-b", 105), + pollMessage("newer-c", 106), + }}, + {Data: []domain.Message{ + pollMessage("boundary-a", 105), + pollMessage("boundary-b", 105), + pollMessage("newer-c", 106), + }}, + }, + }, + } + + var calls []notifyCall + poller := NewMessagePoller(client, "grant-123", base, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + for i := 0; i < 3; i++ { + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() #%d error = %v", i+1, err) + } + } + + assertNotifyIDs(t, calls, []string{"boundary-a", "boundary-b", "newer-c"}) + assertUniqueNotifyIDs(t, calls) + assertQueries(t, client.params, []wantQuery{ + {receivedAfter: 99, pageToken: ""}, + {receivedAfter: 104, pageToken: ""}, + {receivedAfter: 105, pageToken: ""}, + }) +} + +func TestMessagePoller_PollOnce_NoNewSecondPollEmitsNothing(t *testing.T) { + client := &fakePollClient{ + pages: map[string][]domain.MessageListResponse{ + "": { + {Data: []domain.Message{{ + ID: "new-1", + GrantID: "grant-123", + Subject: "Hello", + Snippet: "First line", + From: []domain.EmailParticipant{{Name: "Ada", Email: "ada@example.com"}}, + Date: time.Unix(101, 0), + Unread: true, + Folders: []string{"inbox"}, + }}}, + {}, + }, + }, + } + + var calls []notifyCall + poller := NewMessagePoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() second error = %v", err) + } + + assertNotifyIDs(t, calls, []string{"new-1"}) + gotPayload, ok := calls[0].params.(messageReceivedPayload) + if !ok { + t.Fatalf("payload type = %T, want messageReceivedPayload", calls[0].params) + } + wantPayload := messageReceivedPayload{ + ID: "new-1", + GrantID: "grant-123", + Subject: "Hello", + Snippet: "First line", + From: []domain.EmailParticipant{{Name: "Ada", Email: "ada@example.com"}}, + Date: 101, + Unread: true, + Folders: []string{"inbox"}, + } + if !reflect.DeepEqual(gotPayload, wantPayload) { + t.Fatalf("payload = %#v, want %#v", gotPayload, wantPayload) + } + + payloadJSON, err := json.Marshal(gotPayload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + wantJSON := `{"id":"new-1","grant_id":"grant-123","subject":"Hello","snippet":"First line","from":[{"name":"Ada","email":"ada@example.com"}],"date":101,"unread":true,"folders":["inbox"]}` + if string(payloadJSON) != wantJSON { + t.Fatalf("payload JSON = %s, want %s", payloadJSON, wantJSON) + } + assertQueries(t, client.params, []wantQuery{ + {receivedAfter: 99, pageToken: ""}, + {receivedAfter: 100, pageToken: ""}, + }) + if poller.cursor != 101 { + t.Fatalf("cursor = %d, want 101", poller.cursor) + } +} + +func TestMessagePoller_PollOnce_ReturnsClientErrorFromLaterPage(t *testing.T) { + clientErr := errors.New("api unavailable") + client := &fakePollClient{ + pages: map[string][]domain.MessageListResponse{ + "": {{ + Data: []domain.Message{pollMessage("new", 101)}, + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }}, + }, + errs: map[string]error{"page-2": clientErr}, + } + called := false + poller := NewMessagePoller(client, "grant-123", 100, func(method string, params any) error { + called = true + return nil + }) + + err := poller.PollOnce(context.Background()) + if !errors.Is(err, clientErr) { + t.Fatalf("PollOnce() error = %v, want %v", err, clientErr) + } + if called { + t.Fatal("notify was called on client error") + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } +} + +func TestMessagePoller_PollOnce_ReturnsErrorWhenPageDrainTruncates(t *testing.T) { + tests := []struct { + name string + }{ + {name: "more pages after cap"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pages := make(map[string][]domain.MessageListResponse, maxMessagePollPages) + pageToken := "" + for page := range maxMessagePollPages { + nextToken := fmt.Sprintf("page-%d", page+1) + pages[pageToken] = []domain.MessageListResponse{{ + Data: []domain.Message{pollMessage(fmt.Sprintf("msg-%03d", page), int64(101+page))}, + Pagination: domain.Pagination{NextCursor: nextToken, HasMore: true}, + }} + pageToken = nextToken + } + + client := &fakePollClient{pages: pages} + var calls []notifyCall + poller := NewMessagePoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, notifyCall{method: method, params: params}) + return nil + }) + + err := poller.PollOnce(context.Background()) + wantErr := fmt.Sprintf("message poll truncated at %d pages; not advancing cursor", maxMessagePollPages) + if err == nil || err.Error() != wantErr { + t.Fatalf("PollOnce() error = %v, want %q", err, wantErr) + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } + if len(calls) != 0 { + t.Fatalf("notify calls = %#v, want none", calls) + } + }) + } +} + +func TestMessagePoller_PollOnce_ReturnsNotifyError(t *testing.T) { + notifyErr := errors.New("websocket closed") + client := &fakePollClient{ + pages: map[string][]domain.MessageListResponse{ + "": {{Data: []domain.Message{pollMessage("new", 101)}}}, + }, + } + poller := NewMessagePoller(client, "grant-123", 100, func(method string, params any) error { + return notifyErr + }) + + err := poller.PollOnce(context.Background()) + if !errors.Is(err, notifyErr) { + t.Fatalf("PollOnce() error = %v, want %v", err, notifyErr) + } +} + +func TestMessagePoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + poller := NewMessagePoller(&fakePollClient{}, "grant-123", 0, func(method string, params any) error { + t.Fatal("notify should not be called") + return nil + }) + + if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { + t.Fatalf("Run() error = %v, want %v", err, context.Canceled) + } +} + +func pollMessages(newest, oldest int64) []domain.Message { + var messages []domain.Message + for ts := newest; ts >= oldest; ts-- { + messages = append(messages, pollMessage(fmt.Sprintf("msg-%03d", ts), ts)) + } + return messages +} + +func pollMessage(id string, unix int64) domain.Message { + return domain.Message{ + ID: id, + GrantID: "grant-123", + Date: time.Unix(unix, 0), + } +} + +func filterMessagesAfter(messages []domain.Message, receivedAfter int64) []domain.Message { + var filtered []domain.Message + for _, msg := range messages { + if msg.Date.Unix() > receivedAfter { + filtered = append(filtered, msg) + } + } + return filtered +} + +func assertNotifyIDs(t *testing.T, calls []notifyCall, want []string) { + t.Helper() + var got []string + for i, call := range calls { + if call.method != "message.received" { + t.Fatalf("notify call %d method = %q, want message.received", i, call.method) + } + payload, ok := call.params.(messageReceivedPayload) + if !ok { + t.Fatalf("notify call %d payload type = %T, want messageReceivedPayload", i, call.params) + } + got = append(got, payload.ID) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("notify IDs = %#v, want %#v", got, want) + } +} + +func assertUniqueNotifyIDs(t *testing.T, calls []notifyCall) { + t.Helper() + seen := make(map[string]struct{}) + for _, call := range calls { + payload := call.params.(messageReceivedPayload) + if _, ok := seen[payload.ID]; ok { + t.Fatalf("duplicate notify ID %q", payload.ID) + } + seen[payload.ID] = struct{}{} + } +} + +type wantQuery struct { + receivedAfter int64 + pageToken string +} + +func assertQueries(t *testing.T, got []domain.MessageQueryParams, want []wantQuery) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("query count = %d, want %d: %#v", len(got), len(want), got) + } + for i := range want { + if got[i].Limit != messagePollLimit || got[i].ReceivedAfter != want[i].receivedAfter || got[i].PageToken != want[i].pageToken { + t.Fatalf("query %d = %+v, want limit %d received_after %d page_token %q", i, got[i], messagePollLimit, want[i].receivedAfter, want[i].pageToken) + } + } +} diff --git a/internal/adapters/rpcserver/poller_threads.go b/internal/adapters/rpcserver/poller_threads.go new file mode 100644 index 0000000..88182ae --- /dev/null +++ b/internal/adapters/rpcserver/poller_threads.go @@ -0,0 +1,91 @@ +package rpcserver + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +const ( + threadPollLimit = 50 + maxThreadPollPages = 20 +) + +type ThreadPoller struct { + incrementalState + client ports.MessageClient + grantID string + notify NotifyFunc +} + +type threadUpdatedPayload struct { + ID string `json:"id"` + Subject string `json:"subject"` + LatestMessageReceivedDate int64 `json:"latest_message_received_date"` + Unread bool `json:"unread"` + MessageCount int `json:"message_count"` +} + +// NewThreadPoller polls threads newer than since. +func NewThreadPoller(client ports.MessageClient, grantID string, since int64, notify NotifyFunc) *ThreadPoller { + return &ThreadPoller{ + incrementalState: incrementalState{cursor: since}, + client: client, + grantID: grantID, + notify: notify, + } +} + +// PollOnce runs one polling cycle. +func (p *ThreadPoller) PollOnce(ctx context.Context) error { + return pollIncremental(ctx, &p.incrementalState, p.fetch, func(thread domain.Thread) int64 { + return thread.LatestMessageRecvDate.Unix() + }, func(thread domain.Thread) string { + return thread.ID + }, "thread.updated", func(thread domain.Thread) any { + return threadUpdatedPayload{ + ID: thread.ID, + Subject: thread.Subject, + LatestMessageReceivedDate: thread.LatestMessageRecvDate.Unix(), + Unread: thread.Unread, + MessageCount: len(thread.MessageIDs), + } + }, p.notify) +} + +func (p *ThreadPoller) fetch(ctx context.Context, queryAfter int64) ([]domain.Thread, error) { + var threads []domain.Thread + pageToken := "" + for page := range maxThreadPollPages { + resp, err := p.client.GetThreadsWithCursor(ctx, p.grantID, &domain.ThreadQueryParams{ + Limit: threadPollLimit, + PageToken: pageToken, + LatestMsgAfter: queryAfter, + }) + if err != nil { + return nil, err + } + if resp == nil { + return nil, errors.New("thread poll response is nil") + } + threads = append(threads, resp.Data...) + if resp.Pagination.NextCursor == "" || !resp.Pagination.HasMore { + break + } + if page == maxThreadPollPages-1 { + return nil, fmt.Errorf("thread poll truncated at %d pages; not advancing cursor", maxThreadPollPages) + } + pageToken = resp.Pagination.NextCursor + } + // ponytail: cap polling bursts at 20 pages; webhooks are the real fix for larger inbox spikes. + return threads, nil +} + +// Run polls until ctx is cancelled. +func (p *ThreadPoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { + return runTicker(ctx, interval, onError, p.PollOnce) +} diff --git a/internal/adapters/rpcserver/poller_threads_test.go b/internal/adapters/rpcserver/poller_threads_test.go new file mode 100644 index 0000000..3f39de0 --- /dev/null +++ b/internal/adapters/rpcserver/poller_threads_test.go @@ -0,0 +1,305 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" +) + +type threadNotifyCall struct { + method string + params any +} + +func TestThreadPoller_PollOnce_EmitsNewThreads(t *testing.T) { + client := &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + if grantID != "grant-123" { + t.Fatalf("grantID = %q, want grant-123", grantID) + } + assertThreadQuery(t, params, 99, "") + return &domain.ThreadListResponse{Data: filterThreadsAfter([]domain.Thread{ + pollThread("thread-2", 102), + { + ID: "thread-1", + Subject: "Hello", + LatestMessageRecvDate: time.Unix(101, 0), + Unread: true, + MessageIDs: []string{"msg-1", "msg-2"}, + }, + }, params.LatestMsgAfter)}, nil + }, + } + + var calls []threadNotifyCall + poller := NewThreadPoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, threadNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + assertThreadNotifyIDs(t, calls, []string{"thread-1", "thread-2"}) + assertUniqueThreadNotifyIDs(t, calls) + gotPayload, ok := calls[0].params.(threadUpdatedPayload) + if !ok { + t.Fatalf("payload type = %T, want threadUpdatedPayload", calls[0].params) + } + wantPayload := threadUpdatedPayload{ + ID: "thread-1", + Subject: "Hello", + LatestMessageReceivedDate: 101, + Unread: true, + MessageCount: 2, + } + if !reflect.DeepEqual(gotPayload, wantPayload) { + t.Fatalf("payload = %#v, want %#v", gotPayload, wantPayload) + } + + payloadJSON, err := json.Marshal(gotPayload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + wantJSON := `{"id":"thread-1","subject":"Hello","latest_message_received_date":101,"unread":true,"message_count":2}` + if string(payloadJSON) != wantJSON { + t.Fatalf("payload JSON = %s, want %s", payloadJSON, wantJSON) + } + if poller.cursor != 102 { + t.Fatalf("cursor = %d, want 102", poller.cursor) + } +} + +func TestThreadPoller_PollOnce_DrainsPagesEmitsAllNewThreads(t *testing.T) { + pages := map[string]domain.ThreadListResponse{ + "": { + Data: pollThreads(175, 126), + Pagination: domain.Pagination{NextCursor: "page-2", HasMore: true}, + }, + "page-2": { + Data: pollThreads(125, 106), + Pagination: domain.Pagination{NextCursor: "page-3", HasMore: true}, + }, + "page-3": { + Data: pollThreads(105, 101), + }, + } + var queries []domain.ThreadQueryParams + client := &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + queries = append(queries, *params) + resp := pages[params.PageToken] + resp.Data = filterThreadsAfter(resp.Data, params.LatestMsgAfter) + return &resp, nil + }, + } + + var calls []threadNotifyCall + poller := NewThreadPoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, threadNotifyCall{method: method, params: params}) + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + + var wantIDs []string + for ts := int64(101); ts <= 175; ts++ { + wantIDs = append(wantIDs, fmt.Sprintf("thread-%03d", ts)) + } + assertThreadNotifyIDs(t, calls, wantIDs) + assertUniqueThreadNotifyIDs(t, calls) + assertThreadQueries(t, queries, []wantThreadQuery{ + {latestAfter: 99, pageToken: ""}, + {latestAfter: 99, pageToken: "page-2"}, + {latestAfter: 99, pageToken: "page-3"}, + }) + if poller.cursor != 175 { + t.Fatalf("cursor = %d, want 175", poller.cursor) + } +} + +func TestThreadPoller_PollOnce_EmitsSameSecondBoundaryThreadOnce(t *testing.T) { + responses := [][]domain.Thread{ + {pollThread("boundary-a", 105)}, + {pollThread("boundary-a", 105), pollThread("boundary-b", 105), pollThread("newer-c", 106)}, + {pollThread("boundary-a", 105), pollThread("boundary-b", 105), pollThread("newer-c", 106)}, + } + var queries []domain.ThreadQueryParams + client := &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + queries = append(queries, *params) + threads := filterThreadsAfter(responses[0], params.LatestMsgAfter) + responses = responses[1:] + return &domain.ThreadListResponse{Data: threads}, nil + }, + } + + var calls []threadNotifyCall + poller := NewThreadPoller(client, "grant-123", 100, func(method string, params any) error { + calls = append(calls, threadNotifyCall{method: method, params: params}) + return nil + }) + + for i := 0; i < 3; i++ { + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() #%d error = %v", i+1, err) + } + } + + assertThreadNotifyIDs(t, calls, []string{"boundary-a", "boundary-b", "newer-c"}) + assertUniqueThreadNotifyIDs(t, calls) + assertThreadQueries(t, queries, []wantThreadQuery{ + {latestAfter: 99}, + {latestAfter: 104}, + {latestAfter: 105}, + }) +} + +func TestThreadPoller_PollOnce_NoNewThreadsEmitsNothing(t *testing.T) { + client := &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + assertThreadQuery(t, params, 99, "") + return &domain.ThreadListResponse{}, nil + }, + } + poller := NewThreadPoller(client, "grant-123", 100, func(method string, params any) error { + t.Fatal("notify should not be called") + return nil + }) + + if err := poller.PollOnce(context.Background()); err != nil { + t.Fatalf("PollOnce() error = %v", err) + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } +} + +func TestThreadPoller_PollOnce_ReturnsClientError(t *testing.T) { + clientErr := errors.New("api unavailable") + client := &fakeThreadClient{ + getThreadsWithCursor: func(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) { + return nil, clientErr + }, + } + called := false + poller := NewThreadPoller(client, "grant-123", 100, func(method string, params any) error { + called = true + return nil + }) + + err := poller.PollOnce(context.Background()) + if !errors.Is(err, clientErr) { + t.Fatalf("PollOnce() error = %v, want %v", err, clientErr) + } + if called { + t.Fatal("notify was called on client error") + } + if poller.cursor != 100 { + t.Fatalf("cursor = %d, want 100", poller.cursor) + } +} + +func TestThreadPoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + poller := NewThreadPoller(&fakeThreadClient{}, "grant-123", 0, func(method string, params any) error { + t.Fatal("notify should not be called") + return nil + }) + + if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { + t.Fatalf("Run() error = %v, want %v", err, context.Canceled) + } +} + +func pollThread(id string, unix int64) domain.Thread { + return domain.Thread{ + ID: id, + Subject: id, + LatestMessageRecvDate: time.Unix(unix, 0), + } +} + +func pollThreads(newest, oldest int64) []domain.Thread { + var threads []domain.Thread + for ts := newest; ts >= oldest; ts-- { + threads = append(threads, pollThread(fmt.Sprintf("thread-%03d", ts), ts)) + } + return threads +} + +func filterThreadsAfter(threads []domain.Thread, latestMsgAfter int64) []domain.Thread { + var filtered []domain.Thread + for _, thread := range threads { + if thread.LatestMessageRecvDate.Unix() > latestMsgAfter { + filtered = append(filtered, thread) + } + } + return filtered +} + +func assertThreadNotifyIDs(t *testing.T, calls []threadNotifyCall, want []string) { + t.Helper() + var got []string + for i, call := range calls { + if call.method != "thread.updated" { + t.Fatalf("notify call %d method = %q, want thread.updated", i, call.method) + } + payload, ok := call.params.(threadUpdatedPayload) + if !ok { + t.Fatalf("notify call %d payload type = %T, want threadUpdatedPayload", i, call.params) + } + got = append(got, payload.ID) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("notify IDs = %#v, want %#v", got, want) + } +} + +func assertUniqueThreadNotifyIDs(t *testing.T, calls []threadNotifyCall) { + t.Helper() + seen := make(map[string]struct{}) + for _, call := range calls { + payload := call.params.(threadUpdatedPayload) + if _, ok := seen[payload.ID]; ok { + t.Fatalf("duplicate notify ID %q", payload.ID) + } + seen[payload.ID] = struct{}{} + } +} + +func assertThreadQuery(t *testing.T, got *domain.ThreadQueryParams, latestAfter int64, pageToken string) { + t.Helper() + if got == nil { + t.Fatal("query params = nil") + } + if got.Limit != threadPollLimit || got.LatestMsgAfter != latestAfter || got.PageToken != pageToken { + t.Fatalf("query = %+v, want limit %d latest_message_after %d page_token %q", got, threadPollLimit, latestAfter, pageToken) + } +} + +type wantThreadQuery struct { + latestAfter int64 + pageToken string +} + +func assertThreadQueries(t *testing.T, got []domain.ThreadQueryParams, want []wantThreadQuery) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("query count = %d, want %d: %#v", len(got), len(want), got) + } + for i, want := range want { + assertThreadQuery(t, &got[i], want.latestAfter, want.pageToken) + } +} diff --git a/internal/adapters/rpcserver/server.go b/internal/adapters/rpcserver/server.go new file mode 100644 index 0000000..a3f6dc4 --- /dev/null +++ b/internal/adapters/rpcserver/server.go @@ -0,0 +1,208 @@ +package rpcserver + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const shutdownTimeout = 5 * time.Second + +type Config struct { + Addr string + Token string + AllowedOrigins []string +} + +type Server struct { + dispatcher *Dispatcher + cfg Config + upgrader websocket.Upgrader + baseCtx context.Context + + mu sync.Mutex + conns map[*websocket.Conn]*clientConn +} + +type clientConn struct { + conn *websocket.Conn + writeMu sync.Mutex +} + +func NewServer(cfg Config, d *Dispatcher) *Server { + return &Server{ + dispatcher: d, + cfg: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return ValidateOrigin(r.Header.Get("Origin"), cfg.AllowedOrigins) + }, + }, + conns: make(map[*websocket.Conn]*clientConn), + } +} + +// Broadcast writes a JSON-RPC notification to every connected client. +func (s *Server) Broadcast(method string, params any) error { + msg, err := NewNotification(method, params) + if err != nil { + return fmt.Errorf("create notification: %w", err) + } + + s.mu.Lock() + conns := make([]*clientConn, 0, len(s.conns)) + for _, c := range s.conns { + conns = append(conns, c) + } + s.mu.Unlock() + + for _, c := range conns { + c.writeMu.Lock() + err := c.conn.WriteMessage(websocket.TextMessage, msg) + c.writeMu.Unlock() + if err != nil { + s.unregister(c) + } + } + + return nil +} + +// Serve starts the WebSocket server and blocks until ctx is cancelled or the server fails. +func (s *Server) Serve(ctx context.Context) error { + s.mu.Lock() + s.baseCtx = ctx + s.mu.Unlock() + + httpServer := &http.Server{ + Addr: s.cfg.Addr, + Handler: s.handler(), + ReadHeaderTimeout: 5 * time.Second, + } + + errCh := make(chan error, 1) + go func() { + err := httpServer.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + errCh <- nil + return + } + errCh <- err + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { + s.closeConns() + return fmt.Errorf("shutdown rpc server: %w", err) + } + s.closeConns() + return <-errCh + } +} + +func (s *Server) handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ws", s.handleWebSocket) + return mux +} + +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !s.authorized(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if !ValidateOrigin(r.Header.Get("Origin"), s.cfg.AllowedOrigins) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + c := &clientConn{conn: conn} + s.register(c) + defer s.unregister(c) + + s.mu.Lock() + baseCtx := s.baseCtx + s.mu.Unlock() + if baseCtx == nil { + baseCtx = context.Background() + } + connCtx, cancel := context.WithCancel(baseCtx) + defer cancel() + + for { + _, msg, err := conn.ReadMessage() + if err != nil { + return + } + resp := s.dispatcher.Dispatch(connCtx, msg) + if resp == nil { + continue + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, resp) + c.writeMu.Unlock() + if err != nil { + return + } + } +} + +func (s *Server) authorized(r *http.Request) bool { + if token := bearerToken(r.Header.Get("Authorization")); ValidateToken(s.cfg.Token, token) { + return true + } + return ValidateToken(s.cfg.Token, r.URL.Query().Get("token")) +} + +func bearerToken(header string) string { + token, ok := strings.CutPrefix(header, "Bearer ") + if !ok { + return "" + } + return token +} + +func (s *Server) register(c *clientConn) { + s.mu.Lock() + defer s.mu.Unlock() + s.conns[c.conn] = c +} + +func (s *Server) unregister(c *clientConn) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.conns[c.conn]; ok { + delete(s.conns, c.conn) + _ = c.conn.Close() + } +} + +func (s *Server) closeConns() { + s.mu.Lock() + conns := make([]*clientConn, 0, len(s.conns)) + for _, c := range s.conns { + conns = append(conns, c) + } + s.mu.Unlock() + + for _, c := range conns { + s.unregister(c) + } +} diff --git a/internal/adapters/rpcserver/server_test.go b/internal/adapters/rpcserver/server_test.go new file mode 100644 index 0000000..9f847f0 --- /dev/null +++ b/internal/adapters/rpcserver/server_test.go @@ -0,0 +1,241 @@ +package rpcserver + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestServer_WebSocketAuthDispatchAndBroadcast(t *testing.T) { + d := NewDispatcher() + d.Register("echo", func(ctx context.Context, params json.RawMessage) (any, error) { + var p struct { + Message string `json:"message"` + } + if err := json.Unmarshal(params, &p); err != nil { + return nil, err + } + return map[string]string{"message": p.Message}, nil + }) + + srv := NewServer(Config{Token: "secret-token"}, d) + httpSrv := newHTTPTestServer(t, srv.handler()) + t.Cleanup(httpSrv.Close) + wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + + tests := []struct { + name string + token string + wantStatus int + }{ + {name: "missing token rejected", wantStatus: http.StatusUnauthorized}, + {name: "wrong token rejected", token: "wrong-token", wantStatus: http.StatusUnauthorized}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, resp, err := websocket.DefaultDialer.Dial(wsURL, authHeader(tt.token)) + if err == nil { + t.Fatal("Dial() error = nil, want handshake failure") + } + if resp == nil { + t.Fatal("Dial() response = nil, want HTTP response") + } + defer func() { + _ = resp.Body.Close() + }() + if resp.StatusCode != tt.wantStatus { + t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus) + } + }) + } + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, authHeader("secret-token")) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + if resp != nil { + defer func() { + _ = resp.Body.Close() + }() + } + t.Cleanup(func() { + _ = conn.Close() + }) + + if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"jsonrpc":"2.0","id":1,"method":"echo","params":{"message":"hi"}}`)); err != nil { + t.Fatalf("WriteMessage() error = %v", err) + } + + var response struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result struct { + Message string `json:"message"` + } `json:"result"` + } + readJSON(t, conn, &response) + if response.JSONRPC != "2.0" || response.ID != 1 || response.Result.Message != "hi" { + t.Fatalf("response = %+v, want echo result", response) + } + + if err := srv.Broadcast("message.received", map[string]string{"id": "msg-1"}); err != nil { + t.Fatalf("Broadcast() error = %v", err) + } + + var notification struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params struct { + ID string `json:"id"` + } `json:"params"` + } + readJSON(t, conn, ¬ification) + if notification.JSONRPC != "2.0" || notification.Method != "message.received" || notification.Params.ID != "msg-1" { + t.Fatalf("notification = %+v, want message.received msg-1", notification) + } +} + +func TestServer_ConcurrentClientWritesAndBroadcast(t *testing.T) { + d := NewDispatcher() + d.Register("echo", func(ctx context.Context, params json.RawMessage) (any, error) { + time.Sleep(2 * time.Millisecond) + return json.RawMessage(params), nil + }) + + srv := NewServer(Config{Token: "secret-token"}, d) + httpSrv := newHTTPTestServer(t, srv.handler()) + t.Cleanup(httpSrv.Close) + wsURL := "ws" + strings.TrimPrefix(httpSrv.URL, "http") + "/ws" + + conn, resp, err := websocket.DefaultDialer.Dial(wsURL, authHeader("secret-token")) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + if resp != nil { + defer func() { + _ = resp.Body.Close() + }() + } + t.Cleanup(func() { + _ = conn.Close() + }) + if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + + broadcastReceived := make(chan struct{}) + readErr := make(chan error, 1) + go func() { + for { + _, msg, err := conn.ReadMessage() + if err != nil { + readErr <- err + return + } + var envelope struct { + Method string `json:"method"` + } + if err := json.Unmarshal(msg, &envelope); err != nil { + readErr <- fmt.Errorf("unmarshal %s: %w", msg, err) + return + } + if envelope.Method == "message.received" { + close(broadcastReceived) + return + } + } + }() + + const writes = 25 + var wg sync.WaitGroup + wg.Add(2) + errs := make(chan error, 2) + + go func() { + defer wg.Done() + for i := range writes { + msg := fmt.Sprintf(`{"jsonrpc":"2.0","id":%d,"method":"echo","params":{"i":%d}}`, i, i) + if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { + errs <- fmt.Errorf("write request %d: %w", i, err) + return + } + } + }() + + go func() { + defer wg.Done() + for i := range writes { + if err := srv.Broadcast("message.received", map[string]int{"i": i}); err != nil { + errs <- fmt.Errorf("broadcast %d: %w", i, err) + return + } + } + }() + + wg.Wait() + close(errs) + for err := range errs { + if err != nil { + t.Fatal(err) + } + } + + select { + case <-broadcastReceived: + case err := <-readErr: + t.Fatalf("read message: %v", err) + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for broadcast") + } +} + +func authHeader(token string) http.Header { + h := http.Header{} + if token != "" { + h.Set("Authorization", "Bearer "+token) + } + return h +} + +func newHTTPTestServer(t *testing.T, h http.Handler) *httptest.Server { + t.Helper() + + var srv *httptest.Server + func() { + defer func() { + if r := recover(); r != nil { + msg := fmt.Sprint(r) + if strings.Contains(msg, "httptest: failed to listen on a port") && strings.Contains(msg, "operation not permitted") { + t.Skipf("local TCP listener unavailable in this sandbox: %v", r) + } + panic(r) + } + }() + srv = httptest.NewServer(h) + }() + return srv +} + +func readJSON(t *testing.T, conn *websocket.Conn, v any) { + t.Helper() + + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("SetReadDeadline() error = %v", err) + } + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("ReadMessage() error = %v", err) + } + if err := json.Unmarshal(msg, v); err != nil { + t.Fatalf("unmarshal %s: %v", msg, err) + } +} diff --git a/internal/cli/agent/rule_validation.go b/internal/cli/agent/rule_validation.go index 23cdef3..edb00e7 100644 --- a/internal/cli/agent/rule_validation.go +++ b/internal/cli/agent/rule_validation.go @@ -266,7 +266,7 @@ func validateRuleActions(actions []domain.RuleAction) error { if _, ok := scalarRuleValue(action.Value); !ok { return common.NewUserError( "assign_to_folder requires a folder value", - "Use --action assign_to_folder=", + "Use --action assign_to_folder=", ) } continue diff --git a/internal/cli/integration/rpc_ext_smoke_test.go b/internal/cli/integration/rpc_ext_smoke_test.go new file mode 100644 index 0000000..4b42c1c --- /dev/null +++ b/internal/cli/integration/rpc_ext_smoke_test.go @@ -0,0 +1,78 @@ +//go:build integration +// +build integration + +package integration + +import "testing" + +// TestCLI_RPC_ExtSmoke proves the newly added email/calendar/contacts/auth handlers +// (RegisterEmailExtHandlers / RegisterCalendarExtHandlers / RegisterContactExtHandlers) +// are wired into the running `nylas rpc serve` binary and reachable end-to-end over the +// WebSocket transport. It is intentionally READ-ONLY — no state is mutated, so there is +// nothing to clean up. Per-method live behavior of the underlying client methods is +// already covered by the CLI integration suite; this only guards the RPC wiring. +func TestCLI_RPC_ExtSmoke(t *testing.T) { + skipIfMissingCreds(t) + + addr, tok := startRPCServer(t, nil) + conn := dialRPC(t, addr, tok) + id := 1 + + t.Run("email.folder.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.folder.list", map[string]any{ + "grant_id": testGrantID, + }) + id++ + if res.IsError { + t.Fatalf("email.folder.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["folders"]; !ok { + t.Fatal("email.folder.list result missing folders key") + } + }) + + t.Run("email.signature.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.signature.list", map[string]any{ + "grant_id": testGrantID, + }) + id++ + if res.IsError { + t.Fatalf("email.signature.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["signatures"]; !ok { + t.Fatal("email.signature.list result missing signatures key") + } + }) + + t.Run("calendar.get primary", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "calendar.get", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + }) + id++ + if res.IsError { + t.Fatalf("calendar.get returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["id"]; !ok { + t.Fatal("calendar.get result missing id key") + } + }) + + t.Run("contact.group.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "contact.group.list", map[string]any{ + "grant_id": testGrantID, + }) + id++ + if res.IsError { + t.Fatalf("contact.group.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + // Some providers (e.g. Google) return groups; the key must always be present. + if _, ok := res.Result["groups"]; !ok { + t.Fatal("contact.group.list result missing groups key") + } + }) +} diff --git a/internal/cli/integration/rpc_extended_test.go b/internal/cli/integration/rpc_extended_test.go new file mode 100644 index 0000000..8a4b200 --- /dev/null +++ b/internal/cli/integration/rpc_extended_test.go @@ -0,0 +1,78 @@ +//go:build integration +// +build integration + +package integration + +import "testing" + +// TestCLI_RPC_ExtendedReads exercises the read methods of the extended domains +// (draft, notetaker, scheduler, template, workflow, admin, workspace, audit, auth, otp). +// +// Local/pure methods (audit.*, auth.url) must succeed outright. API-backed methods +// (admin/scheduler/template/...) only need to RESPOND end-to-end — a success OR a +// well-formed RPC error both prove the method is wired and the handler ran; the test +// tolerates permission/empty-resource errors since the test account may not have +// admin access or any scheduler configs. Deep CRUD of these domains is covered by the +// unit tests (creating real connectors/configs/credentials live is impractical/unsafe). +func TestCLI_RPC_ExtendedReads(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, nil) + conn := dialRPC(t, addr, tok) + + id := 0 + nextID := func() int { id++; return id } + + // Group 1: local / pure builders — must succeed, with the expected result key. + local := []struct { + name, method, key string + params map[string]any + }{ + {"audit.list", "audit.list", "entries", map[string]any{"limit": 5}}, + {"audit.stats", "audit.stats", "file_count", nil}, + {"audit.path", "audit.path", "path", nil}, + {"audit.config.read", "audit.config.read", "", nil}, + {"auth.url", "auth.url", "url", map[string]any{"provider": "google", "redirect_uri": "http://localhost/callback"}}, + } + for _, tc := range local { + t.Run(tc.name, func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, nextID(), tc.method, tc.params) + if res.IsError { + t.Fatalf("%s unexpected RPC error %d %q", tc.method, res.ErrCode, res.ErrMsg) + } + if tc.key != "" { + if _, ok := res.Result[tc.key]; !ok { + t.Fatalf("%s result missing key %q; got %v", tc.method, tc.key, res.Result) + } + } + }) + } + + // Group 2: API-backed reads — must respond end-to-end; success or a clean RPC error are both OK. + apiBacked := []struct { + name, method string + params map[string]any + }{ + {"draft.list", "draft.list", map[string]any{"grant_id": testGrantID, "limit": 2}}, + {"notetaker.list", "notetaker.list", map[string]any{"grant_id": testGrantID}}, + {"scheduler.config.list", "scheduler.config.list", nil}, + {"template.list", "template.list", map[string]any{"scope": "app"}}, + {"workflow.list", "workflow.list", map[string]any{"scope": "app"}}, + {"admin.app.list", "admin.app.list", nil}, + {"admin.connector.list", "admin.connector.list", nil}, + {"workspace.list", "workspace.list", nil}, + {"auth.grant.get", "auth.grant.get", map[string]any{"grant_id": testGrantID}}, + } + for _, tc := range apiBacked { + t.Run(tc.name, func(t *testing.T) { + acquireRateLimit(t) + // rpcCall fails the test only on a transport/read failure; a returned result OR a + // structured RPC error both mean the method is wired and reachable. + res := rpcCall(t, conn, nextID(), tc.method, tc.params) + if res.IsError { + t.Logf("%s returned RPC error %d %q (acceptable — account may lack permission/resources)", + tc.method, res.ErrCode, res.ErrMsg) + } + }) + } +} diff --git a/internal/cli/integration/rpc_notifications_test.go b/internal/cli/integration/rpc_notifications_test.go new file mode 100644 index 0000000..fc0b9d5 --- /dev/null +++ b/internal/cli/integration/rpc_notifications_test.go @@ -0,0 +1,58 @@ +//go:build integration +// +build integration + +package integration + +import ( + "encoding/json" + "fmt" + "os" + "strings" + "testing" + "time" +) + +func TestCLI_RPC_Notification_MessageReceived(t *testing.T) { + skipIfMissingCreds(t) + + recipient := strings.TrimSpace(getTestEmail()) + if recipient == "" { + t.Skip("no test email configured") + } + + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": testGrantID}) + conn := dialRPC(t, addr, tok) + + marker := fmt.Sprintf("RPC-IT-%d-%d", os.Getpid(), time.Now().UnixNano()) + subject := marker + " notification test" + + time.Sleep(2 * time.Second) + acquireRateLimit(t) + + stdout, stderr, err := runCLIWithOverrides(2*time.Minute, map[string]string{"NYLAS_GRANT_ID": testGrantID}, + "email", "send", + "-t", recipient, + "-s", subject, + "-b", "rpc integration notification test", + "-y", + ) + if err != nil { + t.Fatalf("email send failed: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + } + + params, ok := waitForNotification(t, conn, "message.received", 90*time.Second, func(p json.RawMessage) bool { + var msg struct { + Subject string `json:"subject"` + } + if err := json.Unmarshal(p, &msg); err != nil { + return false + } + return strings.Contains(msg.Subject, marker) + }) + if !ok { + t.Fatalf("timed out waiting for message.received notification with marker %q", marker) + } + if len(params) == 0 { + t.Fatal("message.received notification returned empty params") + } +} diff --git a/internal/cli/integration/rpc_protocol_test.go b/internal/cli/integration/rpc_protocol_test.go new file mode 100644 index 0000000..d016903 --- /dev/null +++ b/internal/cli/integration/rpc_protocol_test.go @@ -0,0 +1,212 @@ +//go:build integration +// +build integration + +package integration + +import ( + "encoding/json" + "errors" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestCLI_RPC_Auth_MissingToken(t *testing.T) { + skipIfMissingCreds(t) + addr, _ := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + + conn, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws", nil) + if conn != nil { + _ = conn.Close() + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err == nil { + t.Fatal("expected websocket dial without token to fail") + } + if resp != nil && resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusUnauthorized) + } +} + +func TestCLI_RPC_Auth_WrongToken(t *testing.T) { + skipIfMissingCreds(t) + addr, _ := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + + conn, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws", + http.Header{"Authorization": {"Bearer wrong"}}) + if conn != nil { + _ = conn.Close() + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err == nil { + t.Fatal("expected websocket dial with wrong token to fail") + } + if resp != nil && resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusUnauthorized) + } +} + +func TestCLI_RPC_Auth_QueryParamToken(t *testing.T) { + skipIfMissingCreds(t) + addr, _ := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + + conn, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws?token="+rpcTestToken, nil) + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err != nil { + t.Fatalf("dial with query token: %v", err) + } + _ = conn.Close() +} + +func TestCLI_RPC_Auth_OriginRejected(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + + conn, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws", http.Header{ + "Authorization": {"Bearer " + tok}, + "Origin": {"http://evil.example"}, + }) + if conn != nil { + _ = conn.Close() + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err == nil { + t.Fatal("expected websocket dial with cross-origin header to fail") + } + if resp == nil { + t.Fatal("expected HTTP response for rejected origin") + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusForbidden) + } +} + +func TestCLI_RPC_NonLoopbackRefused(t *testing.T) { + if testBinary == "" { + t.Skip("no binary") + } + + _, stderr, err := runCLI("rpc", "serve", "--addr", "0.0.0.0:12345") + if err == nil { + t.Fatal("expected non-loopback bind without --allow-remote to fail") + } + if !strings.Contains(stderr, "refusing to bind") { + t.Fatalf("stderr = %q, want substring %q", stderr, "refusing to bind") + } +} + +func TestCLI_RPC_UnknownMethod(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + conn := dialRPC(t, addr, tok) + + result := rpcCall(t, conn, 1, "does.not.exist", nil) + if !result.IsError || result.ErrCode != -32601 { + t.Fatalf("rpc error = (%v, %d), want (true, -32601)", result.IsError, result.ErrCode) + } +} + +func TestCLI_RPC_MissingRequiredParam(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + conn := dialRPC(t, addr, tok) + + result := rpcCall(t, conn, 2, "email.get", map[string]any{"grant_id": "x"}) + if !result.IsError || result.ErrCode != -32602 { + t.Fatalf("rpc error = (%v, %d), want (true, -32602)", result.IsError, result.ErrCode) + } +} + +func TestCLI_RPC_MalformedJSON(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + conn := dialRPC(t, addr, tok) + + if err := conn.WriteMessage(websocket.TextMessage, []byte("{not valid json")); err != nil { + t.Fatalf("write malformed json: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set read deadline: %v", err) + } + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read parse error response: %v", err) + } + var resp struct { + Error struct { + Code int `json:"code"` + } `json:"error"` + ID json.RawMessage `json:"id"` + } + if err := json.Unmarshal(msg, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Error.Code != -32700 { + t.Fatalf("error.code = %d, want -32700", resp.Error.Code) + } + if string(resp.ID) != "null" { + t.Fatalf("id = %s, want null", resp.ID) + } +} + +func TestCLI_RPC_BadVersion(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + conn := dialRPC(t, addr, tok) + + if err := conn.WriteMessage(websocket.TextMessage, []byte(`{"id":1,"method":"email.list"}`)); err != nil { + t.Fatalf("write bad-version request: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + t.Fatalf("set read deadline: %v", err) + } + _, msg, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read bad-version response: %v", err) + } + var resp struct { + Error struct { + Code int `json:"code"` + } `json:"error"` + } + if err := json.Unmarshal(msg, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Error.Code != -32600 { + t.Fatalf("error.code = %d, want -32600", resp.Error.Code) + } +} + +func TestCLI_RPC_Notification_NoReply(t *testing.T) { + skipIfMissingCreds(t) + addr, tok := startRPCServer(t, map[string]string{"NYLAS_GRANT_ID": ""}) + conn := dialRPC(t, addr, tok) + + if err := conn.WriteJSON(map[string]any{ + "jsonrpc": "2.0", + "method": "client.focus", + "params": map[string]any{"focused": true}, + }); err != nil { + t.Fatalf("write notification: %v", err) + } + if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("set read deadline: %v", err) + } + _, _, err := conn.ReadMessage() + var netErr net.Error + if !errors.As(err, &netErr) || !netErr.Timeout() { + t.Fatalf("ReadMessage error = %v, want timeout", err) + } +} diff --git a/internal/cli/integration/rpc_reads_test.go b/internal/cli/integration/rpc_reads_test.go new file mode 100644 index 0000000..c6f668e --- /dev/null +++ b/internal/cli/integration/rpc_reads_test.go @@ -0,0 +1,224 @@ +//go:build integration +// +build integration + +package integration + +import ( + "encoding/json" + "testing" +) + +func TestCLI_RPC_Reads(t *testing.T) { + skipIfMissingCreds(t) + + addr, tok := startRPCServer(t, nil) + conn := dialRPC(t, addr, tok) + id := 1 + + t.Run("email.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.list", map[string]any{ + "grant_id": testGrantID, + "limit": 2, + }) + id++ + if res.IsError { + t.Fatalf("email.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["messages"]; !ok { + t.Fatal("email.list result missing messages key") + } + var nextCursor string + if raw, ok := res.Result["next_cursor"]; ok && string(raw) != "null" { + if err := json.Unmarshal(raw, &nextCursor); err != nil { + t.Fatalf("unmarshal next_cursor: %v", err) + } + } + }) + + t.Run("email.get", func(t *testing.T) { + acquireRateLimit(t) + listRes := rpcCall(t, conn, id, "email.list", map[string]any{ + "grant_id": testGrantID, + "limit": 2, + }) + id++ + if listRes.IsError { + t.Fatalf("email.list returned RPC error %d %q", listRes.ErrCode, listRes.ErrMsg) + } + + var messages []struct { + ID string `json:"id"` + } + if err := json.Unmarshal(listRes.Result["messages"], &messages); err != nil { + t.Fatalf("unmarshal messages: %v", err) + } + if len(messages) == 0 { + t.Skip("no messages available") + } + + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.get", map[string]any{ + "grant_id": testGrantID, + "message_id": messages[0].ID, + }) + id++ + if res.IsError { + t.Fatalf("email.get returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["id"]; !ok { + t.Fatal("email.get result missing id key") + } + }) + + t.Run("thread.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "thread.list", map[string]any{ + "grant_id": testGrantID, + "limit": 2, + }) + id++ + if res.IsError { + t.Fatalf("thread.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["threads"]; !ok { + t.Fatal("thread.list result missing threads key") + } + }) + + t.Run("calendar.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "calendar.list", map[string]any{ + "grant_id": testGrantID, + }) + id++ + if res.IsError { + t.Fatalf("calendar.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["calendars"]; !ok { + t.Fatal("calendar.list result missing calendars key") + } + }) + + t.Run("event.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "event.list", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + "limit": 2, + }) + id++ + if res.IsError { + t.Fatalf("event.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["events"]; !ok { + t.Fatal("event.list result missing events key") + } + }) + + t.Run("contact.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "contact.list", map[string]any{ + "grant_id": testGrantID, + "limit": 2, + }) + id++ + if res.IsError { + t.Fatalf("contact.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["contacts"]; !ok { + t.Fatal("contact.list result missing contacts key") + } + }) + + t.Run("agentAccount.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "agentAccount.list", map[string]any{}) + id++ + if res.IsError { + t.Fatalf("agentAccount.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["accounts"]; !ok { + t.Fatal("agentAccount.list result missing accounts key") + } + }) + + t.Run("grant.list", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "grant.list", map[string]any{}) + id++ + if res.IsError { + t.Fatalf("grant.list returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["grants"]; !ok { + t.Fatal("grant.list result missing grants key") + } + }) + + t.Run("config.read", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "config.read", map[string]any{}) + id++ + if res.IsError { + t.Fatalf("config.read returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + for _, key := range []string{"region", "ai_configured"} { + if _, ok := res.Result[key]; !ok { + t.Fatalf("config.read result missing %s key", key) + } + } + for _, key := range []string{"api_key", "client_secret", "grants", "ai", "gpg", "dashboard"} { + if _, ok := res.Result[key]; ok { + t.Fatalf("config.read result unexpectedly included %s key", key) + } + } + }) + + t.Run("email.list pagination", func(t *testing.T) { + acquireRateLimit(t) + firstRes := rpcCall(t, conn, id, "email.list", map[string]any{ + "grant_id": testGrantID, + "limit": 1, + }) + id++ + if firstRes.IsError { + t.Fatalf("email.list returned RPC error %d %q", firstRes.ErrCode, firstRes.ErrMsg) + } + + var nextCursor string + if raw, ok := firstRes.Result["next_cursor"]; ok && string(raw) != "null" { + if err := json.Unmarshal(raw, &nextCursor); err != nil { + t.Fatalf("unmarshal next_cursor: %v", err) + } + } + if nextCursor == "" { + t.Skip("no next_cursor available") + } + + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.list", map[string]any{ + "grant_id": testGrantID, + "limit": 1, + "page_token": nextCursor, + }) + id++ + if res.IsError { + t.Fatalf("email.list page 2 returned RPC error %d %q", res.ErrCode, res.ErrMsg) + } + if _, ok := res.Result["messages"]; !ok { + t.Fatal("email.list page 2 result missing messages key") + } + }) + + t.Run("email.get not found", func(t *testing.T) { + acquireRateLimit(t) + res := rpcCall(t, conn, id, "email.get", map[string]any{ + "grant_id": testGrantID, + "message_id": "definitely-not-a-real-id", + }) + id++ + if !res.IsError { + t.Fatal("email.get not found returned success") + } + }) +} diff --git a/internal/cli/integration/rpc_test.go b/internal/cli/integration/rpc_test.go new file mode 100644 index 0000000..0ec42a3 --- /dev/null +++ b/internal/cli/integration/rpc_test.go @@ -0,0 +1,161 @@ +//go:build integration +// +build integration + +package integration + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "os" + "os/exec" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestCLI_RPCServeHelp(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + stdout, stderr, err := runCLI("rpc", "serve", "--help") + if err != nil { + t.Fatalf("rpc serve --help failed: %v\nstderr: %s", err, stderr) + } + + output := stdout + stderr + for _, want := range []string{"JSON-RPC", "serve"} { + if !strings.Contains(output, want) { + t.Errorf("expected rpc serve help to contain %q, got stdout: %s\nstderr: %s", want, stdout, stderr) + } + } +} + +func TestCLI_RPCServe_AuthAndEmailList(t *testing.T) { + skipIfMissingCreds(t) + + port := freeTCPPort(t) + addr := "127.0.0.1:" + strconv.Itoa(port) + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, testBinary, "rpc", "serve", "--addr", addr) + cmd.Env = cliTestEnv(map[string]string{"NYLAS_WS_TOKEN": "integration-rpc-token"}) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start rpc server: %v", err) + } + + var stopOnce sync.Once + stopServer := func() { + stopOnce.Do(func() { + if cmd.Process != nil { + _ = cmd.Process.Signal(os.Interrupt) + } + _ = cmd.Wait() + }) + } + t.Cleanup(stopServer) + + wsURL := "ws://" + addr + "/ws" + goodHeader := http.Header{"Authorization": {"Bearer integration-rpc-token"}} + var conn *websocket.Conn + var lastErr error + deadline := time.Now().Add(10 * time.Second) + for time.Now().Before(deadline) { + var resp *http.Response + conn, resp, lastErr = websocket.DefaultDialer.Dial(wsURL, goodHeader) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if lastErr == nil { + break + } + time.Sleep(200 * time.Millisecond) + } + if lastErr != nil { + stopServer() + t.Fatalf("rpc server did not become ready: %v\nstdout: %s\nstderr: %s", lastErr, stdout.String(), stderr.String()) + } + defer func() { _ = conn.Close() }() + + badConn, resp, err := websocket.DefaultDialer.Dial(wsURL, http.Header{"Authorization": {"Bearer wrong-token"}}) + if badConn != nil { + _ = badConn.Close() + } + if resp != nil && resp.Body != nil { + defer func() { _ = resp.Body.Close() }() + } + if err == nil { + stopServer() + t.Fatalf("expected websocket dial with wrong bearer token to fail\nstdout: %s\nstderr: %s", stdout.String(), stderr.String()) + } + if resp != nil && resp.StatusCode != http.StatusUnauthorized { + stopServer() + t.Fatalf("wrong-token websocket status = %d, want %d\nstdout: %s\nstderr: %s", resp.StatusCode, http.StatusUnauthorized, stdout.String(), stderr.String()) + } + + acquireRateLimit(t) + request := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "email.list", + "params": map[string]any{ + "grant_id": testGrantID, + "limit": 2, + }, + } + if err := conn.WriteJSON(request); err != nil { + stopServer() + t.Fatalf("failed to write email.list request: %v\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil { + stopServer() + t.Fatalf("failed to set websocket read deadline: %v\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + + for { + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.RawMessage `json:"id"` + Result map[string]json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := conn.ReadJSON(&response); err != nil { + stopServer() + t.Fatalf("failed to read email.list response: %v\nstdout: %s\nstderr: %s", err, stdout.String(), stderr.String()) + } + if string(response.ID) != "1" { + continue + } + if response.Error != nil { + stopServer() + t.Fatalf("email.list returned RPC error %d %q\nstdout: %s\nstderr: %s", response.Error.Code, response.Error.Message, stdout.String(), stderr.String()) + } + if _, ok := response.Result["messages"]; !ok { + stopServer() + t.Fatalf("email.list result missing messages key: %s\nstdout: %s\nstderr: %s", string(mustMarshalJSON(response.Result)), stdout.String(), stderr.String()) + } + break + } +} + +func mustMarshalJSON(v any) []byte { + b, err := json.Marshal(v) + if err != nil { + return []byte("") + } + return b +} diff --git a/internal/cli/integration/rpc_testutil_test.go b/internal/cli/integration/rpc_testutil_test.go new file mode 100644 index 0000000..2ceb9d9 --- /dev/null +++ b/internal/cli/integration/rpc_testutil_test.go @@ -0,0 +1,170 @@ +//go:build integration +// +build integration + +package integration + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "os" + "os/exec" + "strconv" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +// rpcTestToken is the fixed WS session token injected into the server subprocess for tests. +const rpcTestToken = "integration-rpc-token" + +// startRPCServer launches `nylas rpc serve` on a free loopback port with the test credentials and a +// known token, waits until it accepts WebSocket connections, and returns its addr + token. +// extraEnv overrides/augments the default test env (e.g. set NYLAS_GRANT_ID to drive the pollers). +// Server shutdown + process reap is registered via t.Cleanup. +func startRPCServer(t *testing.T, extraEnv map[string]string) (addr, token string) { + t.Helper() + if testBinary == "" { + t.Skip("CLI binary not found") + } + + addr = "127.0.0.1:" + strconv.Itoa(freeTCPPort(t)) + ctx, cancel := context.WithCancel(context.Background()) + + env := map[string]string{"NYLAS_WS_TOKEN": rpcTestToken} + for k, v := range extraEnv { + env[k] = v + } + + cmd := exec.CommandContext(ctx, testBinary, "rpc", "serve", "--addr", addr) + cmd.Env = cliTestEnv(env) + var stderr bytes.Buffer + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + cancel() + t.Fatalf("start rpc server: %v", err) + } + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Signal(os.Interrupt) + } + cancel() + _ = cmd.Wait() + }) + + // Readiness: retry-dial until the WebSocket handshake succeeds. + deadline := time.Now().Add(15 * time.Second) + for time.Now().Before(deadline) { + c, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws", + http.Header{"Authorization": {"Bearer " + rpcTestToken}}) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if err == nil { + _ = c.Close() + return addr, rpcTestToken + } + time.Sleep(200 * time.Millisecond) + } + t.Fatalf("rpc server did not become ready\nstderr: %s", stderr.String()) + return "", "" +} + +// dialRPC opens an authenticated WebSocket connection (Authorization: Bearer). Close is registered. +func dialRPC(t *testing.T, addr, token string) *websocket.Conn { + t.Helper() + conn, resp, err := websocket.DefaultDialer.Dial("ws://"+addr+"/ws", + http.Header{"Authorization": {"Bearer " + token}}) + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + if err != nil { + t.Fatalf("dial rpc: %v", err) + } + t.Cleanup(func() { _ = conn.Close() }) + return conn +} + +// rpcResult is the parsed outcome of a JSON-RPC call: either Result is set, or IsError with code/message. +type rpcResult struct { + Result map[string]json.RawMessage + IsError bool + ErrCode int + ErrMsg string +} + +// rpcCall sends a request, reads until the response with the matching id arrives (skipping any +// interleaved notifications), and returns the parsed result or error. +func rpcCall(t *testing.T, conn *websocket.Conn, id int, method string, params map[string]any) rpcResult { + t.Helper() + req := map[string]any{"jsonrpc": "2.0", "id": id, "method": method} + if params != nil { + req["params"] = params + } + if err := conn.WriteJSON(req); err != nil { + t.Fatalf("write %s: %v", method, err) + } + if err := conn.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil { + t.Fatalf("set read deadline: %v", err) + } + for { + var resp struct { + ID json.RawMessage `json:"id"` + Result map[string]json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := conn.ReadJSON(&resp); err != nil { + t.Fatalf("read %s response: %v", method, err) + } + if string(resp.ID) != strconv.Itoa(id) { + continue // notification or a different id + } + if resp.Error != nil { + return rpcResult{IsError: true, ErrCode: resp.Error.Code, ErrMsg: resp.Error.Message} + } + return rpcResult{Result: resp.Result} + } +} + +// rpcID extracts a string "id" field from a result object (most create/get results carry one). +func rpcID(t *testing.T, result map[string]json.RawMessage) string { + t.Helper() + var s string + if raw, ok := result["id"]; ok { + _ = json.Unmarshal(raw, &s) + } + return s +} + +// waitForNotification reads frames until a server->client notification (no id) with the given method +// (and optional matching predicate) arrives, or the timeout elapses. +func waitForNotification(t *testing.T, conn *websocket.Conn, method string, timeout time.Duration, match func(json.RawMessage) bool) (json.RawMessage, bool) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if err := conn.SetReadDeadline(deadline); err != nil { + return nil, false + } + var msg struct { + ID json.RawMessage `json:"id"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + } + if err := conn.ReadJSON(&msg); err != nil { + return nil, false + } + if msg.ID != nil { + continue // a response, not a notification + } + if msg.Method == method && (match == nil || match(msg.Params)) { + return msg.Params, true + } + } + return nil, false +} diff --git a/internal/cli/integration/rpc_writes_test.go b/internal/cli/integration/rpc_writes_test.go new file mode 100644 index 0000000..b28e957 --- /dev/null +++ b/internal/cli/integration/rpc_writes_test.go @@ -0,0 +1,187 @@ +//go:build integration +// +build integration + +package integration + +import ( + "encoding/json" + "testing" + "time" +) + +func TestCLI_RPC_Writes(t *testing.T) { + skipIfMissingCreds(t) + + addr, token := startRPCServer(t, nil) + conn := dialRPC(t, addr, token) + id := 0 + nextID := func() int { + id++ + return id + } + call := func(t *testing.T, method string, params map[string]any) rpcResult { + t.Helper() + acquireRateLimit(t) + return rpcCall(t, conn, nextID(), method, params) + } + requireNoError := func(t *testing.T, method string, result rpcResult) { + t.Helper() + if result.IsError { + t.Fatalf("%s returned RPC error %d %q", method, result.ErrCode, result.ErrMsg) + } + } + requireInvalidParams := func(t *testing.T, method string, result rpcResult) { + t.Helper() + if !result.IsError || result.ErrCode != -32602 { + t.Fatalf("%s error = (%v, %d, %q), want code -32602", method, result.IsError, result.ErrCode, result.ErrMsg) + } + } + requireDeleted := func(t *testing.T, method string, result rpcResult) { + t.Helper() + requireNoError(t, method, result) + var deleted bool + if raw, ok := result.Result["deleted"]; ok { + _ = json.Unmarshal(raw, &deleted) + } + if !deleted { + t.Fatalf("%s deleted = false, result = %v", method, result.Result) + } + } + skipCreateError := func(t *testing.T, method string, result rpcResult) { + t.Helper() + if result.IsError && result.ErrCode == -32603 { + t.Skipf("%s not supported by this provider/account: RPC error %d %q", method, result.ErrCode, result.ErrMsg) + } + } + + t.Run("draft round-trip", func(t *testing.T) { + requireInvalidParams(t, "draft.update", call(t, "draft.update", map[string]any{ + "grant_id": testGrantID, + })) + + create := call(t, "draft.create", map[string]any{ + "grant_id": testGrantID, + "subject": "RPC-IT draft (delete me)", + "body": "x", + }) + skipCreateError(t, "draft.create", create) + requireNoError(t, "draft.create", create) + draftID := rpcID(t, create.Result) + if draftID == "" { + t.Fatal("draft.create result missing id") + } + + deleted := false + t.Cleanup(func() { + if deleted { + return + } + requireDeleted(t, "draft.delete cleanup", call(t, "draft.delete", map[string]any{ + "grant_id": testGrantID, + "draft_id": draftID, + })) + }) + + requireNoError(t, "draft.update", call(t, "draft.update", map[string]any{ + "grant_id": testGrantID, + "draft_id": draftID, + "subject": "RPC-IT draft updated", + "body": "y", + })) + requireDeleted(t, "draft.delete", call(t, "draft.delete", map[string]any{ + "grant_id": testGrantID, + "draft_id": draftID, + })) + deleted = true + }) + + t.Run("contact round-trip", func(t *testing.T) { + requireInvalidParams(t, "contact.update", call(t, "contact.update", map[string]any{ + "grant_id": testGrantID, + })) + + create := call(t, "contact.create", map[string]any{ + "grant_id": testGrantID, + "given_name": "RPCIT", + "surname": "DeleteMe", + }) + skipCreateError(t, "contact.create", create) + requireNoError(t, "contact.create", create) + contactID := rpcID(t, create.Result) + if contactID == "" { + t.Fatal("contact.create result missing id") + } + + deleted := false + t.Cleanup(func() { + if deleted { + return + } + requireDeleted(t, "contact.delete cleanup", call(t, "contact.delete", map[string]any{ + "grant_id": testGrantID, + "contact_id": contactID, + })) + }) + + requireNoError(t, "contact.update", call(t, "contact.update", map[string]any{ + "grant_id": testGrantID, + "contact_id": contactID, + "given_name": "RPCITUpdated", + })) + requireDeleted(t, "contact.delete", call(t, "contact.delete", map[string]any{ + "grant_id": testGrantID, + "contact_id": contactID, + })) + deleted = true + }) + + t.Run("event round-trip", func(t *testing.T) { + requireInvalidParams(t, "event.update", call(t, "event.update", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + })) + + start := time.Now().Add(30 * 24 * time.Hour).Unix() + create := call(t, "event.create", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + "title": "RPC-IT event (delete me)", + "when": map[string]any{ + "object": "timespan", + "start_time": start, + "end_time": start + int64(time.Hour/time.Second), + }, + }) + skipCreateError(t, "event.create", create) + requireNoError(t, "event.create", create) + eventID := rpcID(t, create.Result) + if eventID == "" { + t.Fatal("event.create result missing id") + } + + deleted := false + t.Cleanup(func() { + if deleted { + return + } + requireDeleted(t, "event.delete cleanup", call(t, "event.delete", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + "event_id": eventID, + })) + }) + + requireNoError(t, "event.update", call(t, "event.update", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + "event_id": eventID, + "title": "RPC-IT event updated", + })) + requireDeleted(t, "event.delete", call(t, "event.delete", map[string]any{ + "grant_id": testGrantID, + "calendar_id": "primary", + "event_id": eventID, + })) + deleted = true + }) +} diff --git a/internal/cli/rpc/rpc.go b/internal/cli/rpc/rpc.go new file mode 100644 index 0000000..544d4f6 --- /dev/null +++ b/internal/cli/rpc/rpc.go @@ -0,0 +1,16 @@ +// Package rpc provides JSON-RPC server commands. +package rpc + +import "github.com/spf13/cobra" + +// NewRPCCmd creates the rpc command with all subcommands. +func NewRPCCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "rpc", + Short: "JSON-RPC WebSocket server for Nylas", + } + + cmd.AddCommand(newServeCmd()) + + return cmd +} diff --git a/internal/cli/rpc/serve.go b/internal/cli/rpc/serve.go new file mode 100644 index 0000000..99ece7c --- /dev/null +++ b/internal/cli/rpc/serve.go @@ -0,0 +1,197 @@ +package rpc + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/nylas/cli/internal/adapters/audit" + "github.com/nylas/cli/internal/adapters/config" + "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/adapters/rpcserver" + otpapp "github.com/nylas/cli/internal/app/otp" + "github.com/nylas/cli/internal/cli/common" + "github.com/spf13/cobra" +) + +const ( + envWSAddr = "NYLAS_WS_ADDR" + defaultAddr = "127.0.0.1:7368" +) + +func newServeCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "serve", + Short: "Start the JSON-RPC WebSocket server", + RunE: runServe, + } + + cmd.Flags().String("addr", "", "address to bind (or NYLAS_WS_ADDR)") + cmd.Flags().Bool("allow-remote", false, "allow binding to a non-loopback address") + + return cmd +} + +func runServe(cmd *cobra.Command, args []string) error { + addr, err := cmd.Flags().GetString("addr") + if err != nil { + return fmt.Errorf("read --addr: %w", err) + } + if addr == "" { + addr = os.Getenv(envWSAddr) + } + if addr == "" { + addr = defaultAddr + } + + allowRemote, err := cmd.Flags().GetBool("allow-remote") + if err != nil { + return fmt.Errorf("read --allow-remote: %w", err) + } + + loopback, err := rpcserver.IsLoopback(addr) + if err != nil { + return err + } + if !loopback && !allowRemote { + return fmt.Errorf("refusing to bind credential-holding RPC socket to non-loopback address %q without --allow-remote", addr) + } + if !loopback { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "WARNING: binding credential-holding RPC socket to non-loopback address %s\n", addr) + } + + client, err := common.GetNylasClient() + if err != nil { + return err + } + grantID, _ := common.GetGrantID(nil) + + store, err := keyring.NewSecretStore(config.DefaultConfigDir()) + if err != nil { + return fmt.Errorf("open secret store: %w", err) + } + token, err := rpcserver.ResolveToken(store, os.Getenv) + if err != nil { + return err + } + + d := rpcserver.NewDispatcher() + d.LogError = func(err error) { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "rpc handler error: %v\n", err) } + rpcserver.RegisterEmailHandlers(d, client, grantID) + rpcserver.RegisterThreadHandlers(d, client, grantID) + rpcserver.RegisterCalendarHandlers(d, client, grantID) + rpcserver.RegisterContactHandlers(d, client, grantID) + rpcserver.RegisterAgentHandlers(d, client) + cfgStore := config.NewDefaultFileStore() + rpcserver.RegisterConfigHandlers(d, cfgStore) + grantStore, gerr := common.NewDefaultGrantStore() + if gerr == nil { + rpcserver.RegisterGrantHandlers(d, grantStore) + } else { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "grant/otp handlers disabled: %v\n", gerr) + } + + // Phase 2 writes (the client confirms before calling; the server executes immediately). + rpcserver.RegisterEmailWriteHandlers(d, client, grantID) + rpcserver.RegisterEmailExtHandlers(d, client, grantID) + rpcserver.RegisterThreadWriteHandlers(d, client, grantID) + rpcserver.RegisterCalendarWriteHandlers(d, client, grantID) + rpcserver.RegisterCalendarExtHandlers(d, client, grantID) + rpcserver.RegisterContactWriteHandlers(d, client, grantID) + rpcserver.RegisterContactExtHandlers(d, client, grantID) + + // Extended domains. + rpcserver.RegisterDraftHandlers(d, client, grantID) + rpcserver.RegisterNotetakerHandlers(d, client, grantID) + rpcserver.RegisterSchedulerHandlers(d, client, grantID) + rpcserver.RegisterTemplateWorkflowHandlers(d, client, grantID) + rpcserver.RegisterAdminHandlers(d, client) + rpcserver.RegisterAuthHandlers(d, client, grantID) + if auditStore, aerr := audit.NewFileStore(""); aerr == nil { + rpcserver.RegisterAuditHandlers(d, auditStore) + } else { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "audit handlers disabled: %v\n", aerr) + } + if gerr == nil { + rpcserver.RegisterOTPHandlers(d, otpapp.NewService(client, grantStore, cfgStore)) + } + + srv := rpcserver.NewServer(rpcserver.Config{ + Addr: addr, + Token: token, + }, d) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ctrl := rpcserver.NewIntervalController(5*time.Second, 30*time.Second) + rpcserver.RegisterFocusHandler(d, ctrl) + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(sigChan) + + go func() { + <-sigChan + cancel() + }() + + if grantID == "" { + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "no default grant configured; live notifications disabled — set a default grant to enable them") + } else { + since := time.Now().Unix() + onErr := func(err error) { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "rpc poll error: %v\n", err) + } + startPoller := func(name string, run func() error) { + go func() { + if err := run(); err != nil && ctx.Err() == nil { + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "rpc %s poller stopped: %v\n", name, err) + } + }() + } + + mp := rpcserver.NewMessagePoller(client, grantID, since, srv.Broadcast) + startPoller("message", func() error { return rpcserver.RunAdaptive(ctx, ctrl, onErr, mp.PollOnce) }) + + tp := rpcserver.NewThreadPoller(client, grantID, since, srv.Broadcast) + startPoller("thread", func() error { return rpcserver.RunAdaptive(ctx, ctrl, onErr, tp.PollOnce) }) + + calendarIDs := []string{"primary"} + calCtx, calCancel := context.WithTimeout(ctx, 10*time.Second) + cals, cerr := client.GetCalendars(calCtx, grantID) + calCancel() + if cerr != nil { + onErr(fmt.Errorf("list calendars for event pollers: %w", cerr)) + } else if len(cals) > 0 { + calendarIDs = calendarIDs[:0] + for _, cal := range cals { + if cal.ID != "" { + calendarIDs = append(calendarIDs, cal.ID) + } + } + if len(calendarIDs) == 0 { + calendarIDs = []string{"primary"} + } + } + + // ponytail: per-calendar polling scales API calls with calendar count; webhooks are the upgrade path. + for _, calendarID := range calendarIDs { + ep := rpcserver.NewEventPoller(client, grantID, calendarID, since, srv.Broadcast) + startPoller("event", func() error { return rpcserver.RunAdaptive(ctx, ctrl, onErr, ep.PollOnce) }) + } + + // ponytail: contacts have no server-side time filter — refetch+diff on a slow cadence. + cp := rpcserver.NewContactPoller(client, grantID, srv.Broadcast) + startPoller("contact", func() error { return cp.Run(ctx, 60*time.Second, onErr) }) + } + + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Nylas RPC WebSocket listening on %s\n", addr) + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Authenticate with Authorization: Bearer or ?token=.") + _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "The token is stored in the keyring or read from NYLAS_WS_TOKEN.") + + return srv.Serve(ctx) +} diff --git a/internal/domain/calendar.go b/internal/domain/calendar.go index 4aedcd5..43712d5 100644 --- a/internal/domain/calendar.go +++ b/internal/domain/calendar.go @@ -217,6 +217,8 @@ type EventQueryParams struct { ShowCancelled bool `json:"show_cancelled,omitempty"` Start int64 `json:"start,omitempty"` // Unix timestamp End int64 `json:"end,omitempty"` // Unix timestamp + UpdatedAfter int64 `json:"updated_after,omitempty"` + UpdatedBefore int64 `json:"updated_before,omitempty"` MetadataPair string `json:"metadata_pair,omitempty"` Busy *bool `json:"busy,omitempty"` OrderBy string `json:"order_by,omitempty"` // start, end diff --git a/internal/domain/contact.go b/internal/domain/contact.go index c9b8b0e..da766af 100644 --- a/internal/domain/contact.go +++ b/internal/domain/contact.go @@ -17,6 +17,7 @@ type Contact struct { Notes string `json:"notes,omitempty"` PictureURL string `json:"picture_url,omitempty"` Picture string `json:"picture,omitempty"` // Base64-encoded image data (when profile_picture=true) + UpdatedAt int64 `json:"updated_at,omitempty"` Emails []ContactEmail `json:"emails,omitempty"` PhoneNumbers []ContactPhone `json:"phone_numbers,omitempty"` WebPages []ContactWebPage `json:"web_pages,omitempty"` diff --git a/internal/ports/messages.go b/internal/ports/messages.go index a5c5b62..838335a 100644 --- a/internal/ports/messages.go +++ b/internal/ports/messages.go @@ -89,6 +89,9 @@ type MessageClient interface { // GetThreads retrieves threads with query parameters. GetThreads(ctx context.Context, grantID string, params *domain.ThreadQueryParams) ([]domain.Thread, error) + // GetThreadsWithCursor retrieves threads with cursor-based pagination. + GetThreadsWithCursor(ctx context.Context, grantID string, params *domain.ThreadQueryParams) (*domain.ThreadListResponse, error) + // GetThread retrieves a specific thread. GetThread(ctx context.Context, grantID, threadID string) (*domain.Thread, error) From e4ff2b6783f7f9028a29e078eaac5ddd9788e2cf Mon Sep 17 00:00:00 2001 From: Qasim Date: Fri, 26 Jun 2026 08:52:26 -0400 Subject: [PATCH 2/2] TW-5722: make RPC poll intervals configurable via env and client.pollConfig Poll intervals were hardcoded (5s/30s focused/idle, 60s contacts). Make them configurable at startup via env vars and live over RPC: - NYLAS_WS_POLL_FAST / _IDLE / _CONTACTS seed the intervals (invalid or non-positive values fall back to the previous defaults). - New client.pollConfig request method updates the live intervals (optional Go-duration fields; omitted fields unchanged) and returns the effective values; takes effect on the next poll cycle, no restart. Contacts now run through RunAdaptive on their own controller, unifying all four pollers on one driver. That retired the dead MessagePoller.Run / ThreadPoller.Run / ContactPoller.Run / runTicker (only ContactPoller.Run was ever wired), along with their tests. Docs: documented the env vars and client.pollConfig in docs/RPC.md. --- docs/RPC.md | 26 ++++- internal/adapters/rpcserver/incremental.go | 102 +++++++++++++++--- .../adapters/rpcserver/incremental_test.go | 85 +++++++++++++++ .../adapters/rpcserver/poller_contacts.go | 5 - .../rpcserver/poller_contacts_test.go | 15 --- .../adapters/rpcserver/poller_messages.go | 6 -- .../rpcserver/poller_messages_test.go | 14 --- internal/adapters/rpcserver/poller_threads.go | 6 -- .../adapters/rpcserver/poller_threads_test.go | 14 --- internal/cli/rpc/serve.go | 25 ++++- internal/cli/rpc/serve_test.go | 32 ++++++ 11 files changed, 250 insertions(+), 80 deletions(-) create mode 100644 internal/cli/rpc/serve_test.go diff --git a/docs/RPC.md b/docs/RPC.md index 2ad66ee..de58936 100644 --- a/docs/RPC.md +++ b/docs/RPC.md @@ -133,6 +133,12 @@ The server holds live Nylas credentials, so the local socket is a real trust bou | `--allow-remote` | permit a non-loopback bind (warns) | `false` | | `NYLAS_WS_TOKEN` | inject the session token (headless/CI) | auto-generated, keyring-brokered | | `NYLAS_DISABLE_KEYRING` | store token/creds in `~/.config/nylas` instead of the keyring | `false` | +| `NYLAS_WS_POLL_FAST` | message/thread/event poll interval while focused (Go duration) | `5s` | +| `NYLAS_WS_POLL_IDLE` | message/thread/event poll interval while idle (Go duration) | `30s` | +| `NYLAS_WS_POLL_CONTACTS` | contact refetch interval (Go duration) | `60s` | + +Invalid or non-positive poll durations fall back to the default. The intervals can also be changed +at runtime via the [`client.pollConfig`](#adaptive-polling) method — no restart required. The server resolves the Nylas API credentials and default grant the same way the rest of the CLI does (keyring, or env/file when `NYLAS_DISABLE_KEYRING=true`). Live pollers run only when a @@ -333,8 +339,24 @@ Send a `client.focus` **notification** (no `id`) to scale the poll interval: { "jsonrpc": "2.0", "method": "client.focus", "params": { "focused": true } } ``` -- `focused: true` → fast interval (5s) for message/thread/event pollers. -- `focused: false` → idle interval (30s). Contacts always poll on a slow 60s cadence. +- `focused: true` → fast interval (default 5s) for message/thread/event pollers. +- `focused: false` → idle interval (default 30s). Contacts poll on their own cadence (default 60s). + +To change the interval **values** themselves at runtime, call `client.pollConfig` (a request, not a +notification). All fields are optional Go durations (`"2s"`, `"1m"`); omitted fields are left +unchanged, and the result reports the effective values. `fast`/`idle` drive the message/thread/event +pollers; `contacts` drives the contact poller. + +```jsonc +// → request +{ "jsonrpc": "2.0", "id": 9, "method": "client.pollConfig", + "params": { "fast": "2s", "idle": "45s", "contacts": "90s" } } +// ← result +{ "jsonrpc": "2.0", "id": 9, "result": { "fast": "2s", "idle": "45s", "contacts": "1m30s" } } +``` + +A non-positive or unparseable duration returns `-32602 invalid params`. Startup defaults come from +the `NYLAS_WS_POLL_*` env vars (see [Configuration](#configuration)). --- diff --git a/internal/adapters/rpcserver/incremental.go b/internal/adapters/rpcserver/incremental.go index 59ed5dc..5243658 100644 --- a/internal/adapters/rpcserver/incremental.go +++ b/internal/adapters/rpcserver/incremental.go @@ -35,25 +35,29 @@ func (c *IntervalController) Current() time.Duration { return c.idle } -type incrementalState struct { - cursor int64 - boundaryIDs map[string]struct{} +// SetIntervals updates the focused/idle durations live. Non-positive values are +// ignored, so a caller can change one bound without touching the other. +func (c *IntervalController) SetIntervals(fast, idle time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + if fast > 0 { + c.fast = fast + } + if idle > 0 { + c.idle = idle + } } -func runTicker(ctx context.Context, interval time.Duration, onError func(error), pollOnce func(context.Context) error) error { - ticker := time.NewTicker(interval) - defer ticker.Stop() +// Intervals returns the current focused and idle durations. +func (c *IntervalController) Intervals() (fast, idle time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + return c.fast, c.idle +} - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - if err := pollOnce(ctx); err != nil && onError != nil { - onError(err) - } - } - } +type incrementalState struct { + cursor int64 + boundaryIDs map[string]struct{} } func RunAdaptive(ctx context.Context, ctrl *IntervalController, onError func(error), pollOnce func(context.Context) error) error { @@ -90,6 +94,72 @@ func RegisterFocusHandler(d *Dispatcher, ctrl *IntervalController) { }) } +type pollConfigParams struct { + Fast string `json:"fast,omitempty"` + Idle string `json:"idle,omitempty"` + Contacts string `json:"contacts,omitempty"` +} + +type pollConfigResult struct { + Fast string `json:"fast"` + Idle string `json:"idle"` + Contacts string `json:"contacts"` +} + +// RegisterPollConfigHandler exposes client.pollConfig, which reads and optionally +// updates the live polling intervals. Durations use Go syntax (e.g. "2s", "1m"); +// omitted or empty fields are left unchanged. The result always reports the +// effective values. ctrl drives messages/threads/events (focused/idle), and +// contactCtrl drives contacts (a single interval, so focused == idle). +func RegisterPollConfigHandler(d *Dispatcher, ctrl, contactCtrl *IntervalController) { + d.Register("client.pollConfig", func(_ context.Context, params json.RawMessage) (any, error) { + var p pollConfigParams + if err := decodeParams(params, &p); err != nil { + return nil, err + } + + fast, err := parsePollInterval("fast", p.Fast) + if err != nil { + return nil, err + } + idle, err := parsePollInterval("idle", p.Idle) + if err != nil { + return nil, err + } + contacts, err := parsePollInterval("contacts", p.Contacts) + if err != nil { + return nil, err + } + + ctrl.SetIntervals(fast, idle) + contactCtrl.SetIntervals(contacts, contacts) + + effFast, effIdle := ctrl.Intervals() + effContacts, _ := contactCtrl.Intervals() + return pollConfigResult{ + Fast: effFast.String(), + Idle: effIdle.String(), + Contacts: effContacts.String(), + }, nil + }) +} + +// parsePollInterval returns 0 for an empty value (meaning "leave unchanged"). +// A non-empty value must parse as a positive Go duration. +func parsePollInterval(field, value string) (time.Duration, error) { + if value == "" { + return 0, nil + } + d, err := time.ParseDuration(value) + if err != nil { + return 0, NewRPCError(InvalidParams, field+` must be a duration (e.g. "2s", "1m")`, err.Error()) + } + if d <= 0 { + return 0, NewRPCError(InvalidParams, field+" must be positive", nil) + } + return d, nil +} + func pollIncremental[T any]( ctx context.Context, st *incrementalState, diff --git a/internal/adapters/rpcserver/incremental_test.go b/internal/adapters/rpcserver/incremental_test.go index 93cc8c7..323d5bd 100644 --- a/internal/adapters/rpcserver/incremental_test.go +++ b/internal/adapters/rpcserver/incremental_test.go @@ -2,6 +2,7 @@ package rpcserver import ( "context" + "encoding/json" "errors" "sync" "testing" @@ -91,6 +92,90 @@ func TestRunAdaptive_PollsReportsErrorsAndReturnsContextError(t *testing.T) { } } +func TestIntervalController_SetIntervals(t *testing.T) { + ctrl := NewIntervalController(5*time.Second, 30*time.Second) + + // Non-positive values leave the corresponding bound unchanged. + ctrl.SetIntervals(2*time.Second, 0) + if fast, idle := ctrl.Intervals(); fast != 2*time.Second || idle != 30*time.Second { + t.Fatalf("Intervals() = (%v, %v), want (2s, 30s)", fast, idle) + } + + ctrl.SetIntervals(0, time.Minute) + if fast, idle := ctrl.Intervals(); fast != 2*time.Second || idle != time.Minute { + t.Fatalf("Intervals() = (%v, %v), want (2s, 1m)", fast, idle) + } + + // The live value follows focus state after an update. + ctrl.SetFocused(true) + if got := ctrl.Current(); got != 2*time.Second { + t.Fatalf("Current() = %v, want 2s", got) + } +} + +func TestRegisterPollConfigHandler(t *testing.T) { + ctrl := NewIntervalController(5*time.Second, 30*time.Second) + contactCtrl := NewIntervalController(60*time.Second, 60*time.Second) + d := NewDispatcher() + RegisterPollConfigHandler(d, ctrl, contactCtrl) + + t.Run("updates provided fields and reports effective values", func(t *testing.T) { + got := d.Dispatch(context.Background(), + []byte(`{"jsonrpc":"2.0","id":1,"method":"client.pollConfig","params":{"fast":"2s","contacts":"90s"}}`)) + + var resp struct { + Result pollConfigResult `json:"result"` + Error *RPCError `json:"error"` + } + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Error != nil { + t.Fatalf("unexpected error: %+v", resp.Error) + } + // fast changed, idle left untouched, contacts changed. + if resp.Result.Fast != "2s" || resp.Result.Idle != "30s" || resp.Result.Contacts != "1m30s" { + t.Fatalf("result = %+v, want {2s 30s 1m30s}", resp.Result) + } + if fast, _ := ctrl.Intervals(); fast != 2*time.Second { + t.Fatalf("controller fast = %v, want 2s", fast) + } + if c, _ := contactCtrl.Intervals(); c != 90*time.Second { + t.Fatalf("contact interval = %v, want 90s", c) + } + }) + + t.Run("rejects an invalid duration", func(t *testing.T) { + got := d.Dispatch(context.Background(), + []byte(`{"jsonrpc":"2.0","id":2,"method":"client.pollConfig","params":{"idle":"nope"}}`)) + + var resp struct { + Error *RPCError `json:"error"` + } + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Error == nil || resp.Error.Code != InvalidParams { + t.Fatalf("error = %+v, want InvalidParams", resp.Error) + } + }) + + t.Run("rejects a non-positive duration", func(t *testing.T) { + got := d.Dispatch(context.Background(), + []byte(`{"jsonrpc":"2.0","id":3,"method":"client.pollConfig","params":{"fast":"0s"}}`)) + + var resp struct { + Error *RPCError `json:"error"` + } + if err := json.Unmarshal(got, &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Error == nil || resp.Error.Code != InvalidParams { + t.Fatalf("error = %+v, want InvalidParams", resp.Error) + } + }) +} + func TestRegisterFocusHandler(t *testing.T) { fast := time.Millisecond idle := time.Second diff --git a/internal/adapters/rpcserver/poller_contacts.go b/internal/adapters/rpcserver/poller_contacts.go index a37de05..314727e 100644 --- a/internal/adapters/rpcserver/poller_contacts.go +++ b/internal/adapters/rpcserver/poller_contacts.go @@ -8,7 +8,6 @@ import ( "sort" "strconv" "strings" - "time" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" @@ -110,10 +109,6 @@ func (p *ContactPoller) PollOnce(ctx context.Context) error { return nil } -func (p *ContactPoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { - return runTicker(ctx, interval, onError, p.PollOnce) -} - func contactFingerprint(c domain.Contact) string { records := []string{} appendRecord := func(parts ...string) { diff --git a/internal/adapters/rpcserver/poller_contacts_test.go b/internal/adapters/rpcserver/poller_contacts_test.go index 1121c2f..f3ea25c 100644 --- a/internal/adapters/rpcserver/poller_contacts_test.go +++ b/internal/adapters/rpcserver/poller_contacts_test.go @@ -7,7 +7,6 @@ import ( "fmt" "reflect" "testing" - "time" "github.com/nylas/cli/internal/domain" ) @@ -375,20 +374,6 @@ func TestContactPoller_PollOnce_ReturnsErrorWithoutCommitWhenPageCapTruncates(t } } -func TestContactPoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - poller := NewContactPoller(&fakeContactClient{}, "grant-123", func(method string, params any) error { - t.Fatal("notify should not be called") - return nil - }) - - if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { - t.Fatalf("Run() error = %v, want %v", err, context.Canceled) - } -} - func scriptedFakeContactClient(t *testing.T, polls []contactPollScript, errs map[string]error) *fakeContactClient { t.Helper() diff --git a/internal/adapters/rpcserver/poller_messages.go b/internal/adapters/rpcserver/poller_messages.go index 8e54ab4..6a7a8d5 100644 --- a/internal/adapters/rpcserver/poller_messages.go +++ b/internal/adapters/rpcserver/poller_messages.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "time" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" @@ -93,8 +92,3 @@ func (p *MessagePoller) fetch(ctx context.Context, queryAfter int64) ([]domain.M // ponytail: cap polling bursts at 20 pages; webhooks are the real fix for larger inbox spikes. return messages, nil } - -// Run polls until ctx is cancelled. -func (p *MessagePoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { - return runTicker(ctx, interval, onError, p.PollOnce) -} diff --git a/internal/adapters/rpcserver/poller_messages_test.go b/internal/adapters/rpcserver/poller_messages_test.go index 4187f68..aea1cee 100644 --- a/internal/adapters/rpcserver/poller_messages_test.go +++ b/internal/adapters/rpcserver/poller_messages_test.go @@ -290,20 +290,6 @@ func TestMessagePoller_PollOnce_ReturnsNotifyError(t *testing.T) { } } -func TestMessagePoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - poller := NewMessagePoller(&fakePollClient{}, "grant-123", 0, func(method string, params any) error { - t.Fatal("notify should not be called") - return nil - }) - - if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { - t.Fatalf("Run() error = %v, want %v", err, context.Canceled) - } -} - func pollMessages(newest, oldest int64) []domain.Message { var messages []domain.Message for ts := newest; ts >= oldest; ts-- { diff --git a/internal/adapters/rpcserver/poller_threads.go b/internal/adapters/rpcserver/poller_threads.go index 88182ae..a90daa2 100644 --- a/internal/adapters/rpcserver/poller_threads.go +++ b/internal/adapters/rpcserver/poller_threads.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "time" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" @@ -84,8 +83,3 @@ func (p *ThreadPoller) fetch(ctx context.Context, queryAfter int64) ([]domain.Th // ponytail: cap polling bursts at 20 pages; webhooks are the real fix for larger inbox spikes. return threads, nil } - -// Run polls until ctx is cancelled. -func (p *ThreadPoller) Run(ctx context.Context, interval time.Duration, onError func(error)) error { - return runTicker(ctx, interval, onError, p.PollOnce) -} diff --git a/internal/adapters/rpcserver/poller_threads_test.go b/internal/adapters/rpcserver/poller_threads_test.go index 3f39de0..3d05188 100644 --- a/internal/adapters/rpcserver/poller_threads_test.go +++ b/internal/adapters/rpcserver/poller_threads_test.go @@ -209,20 +209,6 @@ func TestThreadPoller_PollOnce_ReturnsClientError(t *testing.T) { } } -func TestThreadPoller_Run_ReturnsContextErrorOnCancel(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - poller := NewThreadPoller(&fakeThreadClient{}, "grant-123", 0, func(method string, params any) error { - t.Fatal("notify should not be called") - return nil - }) - - if err := poller.Run(ctx, time.Hour, nil); !errors.Is(err, context.Canceled) { - t.Fatalf("Run() error = %v, want %v", err, context.Canceled) - } -} - func pollThread(id string, unix int64) domain.Thread { return domain.Thread{ ID: id, diff --git a/internal/cli/rpc/serve.go b/internal/cli/rpc/serve.go index 99ece7c..b4ed1e9 100644 --- a/internal/cli/rpc/serve.go +++ b/internal/cli/rpc/serve.go @@ -20,8 +20,23 @@ import ( const ( envWSAddr = "NYLAS_WS_ADDR" defaultAddr = "127.0.0.1:7368" + + envPollFast = "NYLAS_WS_POLL_FAST" + envPollIdle = "NYLAS_WS_POLL_IDLE" + envPollContacts = "NYLAS_WS_POLL_CONTACTS" ) +// pollInterval reads a positive Go duration from env (e.g. "2s", "1m"), +// falling back to def when unset or invalid. +func pollInterval(getenv func(string) string, key string, def time.Duration) time.Duration { + if v := getenv(key); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return def +} + func newServeCmd() *cobra.Command { cmd := &cobra.Command{ Use: "serve", @@ -127,8 +142,14 @@ func runServe(cmd *cobra.Command, args []string) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ctrl := rpcserver.NewIntervalController(5*time.Second, 30*time.Second) + fast := pollInterval(os.Getenv, envPollFast, 5*time.Second) + idle := pollInterval(os.Getenv, envPollIdle, 30*time.Second) + contactInterval := pollInterval(os.Getenv, envPollContacts, 60*time.Second) + + ctrl := rpcserver.NewIntervalController(fast, idle) + contactCtrl := rpcserver.NewIntervalController(contactInterval, contactInterval) rpcserver.RegisterFocusHandler(d, ctrl) + rpcserver.RegisterPollConfigHandler(d, ctrl, contactCtrl) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) @@ -186,7 +207,7 @@ func runServe(cmd *cobra.Command, args []string) error { // ponytail: contacts have no server-side time filter — refetch+diff on a slow cadence. cp := rpcserver.NewContactPoller(client, grantID, srv.Broadcast) - startPoller("contact", func() error { return cp.Run(ctx, 60*time.Second, onErr) }) + startPoller("contact", func() error { return rpcserver.RunAdaptive(ctx, contactCtrl, onErr, cp.PollOnce) }) } _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Nylas RPC WebSocket listening on %s\n", addr) diff --git a/internal/cli/rpc/serve_test.go b/internal/cli/rpc/serve_test.go new file mode 100644 index 0000000..5e3ae2d --- /dev/null +++ b/internal/cli/rpc/serve_test.go @@ -0,0 +1,32 @@ +package rpc + +import ( + "testing" + "time" +) + +func TestPollInterval(t *testing.T) { + const def = 30 * time.Second + + tests := []struct { + name string + env map[string]string + want time.Duration + }{ + {name: "unset falls back to default", env: nil, want: def}, + {name: "empty falls back to default", env: map[string]string{"K": ""}, want: def}, + {name: "valid duration is used", env: map[string]string{"K": "2s"}, want: 2 * time.Second}, + {name: "invalid duration falls back", env: map[string]string{"K": "nope"}, want: def}, + {name: "zero falls back", env: map[string]string{"K": "0s"}, want: def}, + {name: "negative falls back", env: map[string]string{"K": "-5s"}, want: def}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + getenv := func(key string) string { return tt.env[key] } + if got := pollInterval(getenv, "K", def); got != tt.want { + t.Fatalf("pollInterval() = %v, want %v", got, tt.want) + } + }) + } +}