diff --git a/.cargo/config.toml b/.cargo/config.toml index 18770323..4f8a72de 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -19,8 +19,4 @@ SQLX_OFFLINE = "true" # To enable sccache (caches across `cargo clean`): # export RUSTC_WRAPPER=sccache # -# Documented here so contributors can find the levers without grepping CI. - -[build] -# The default `target/` location is fine. Override with CARGO_TARGET_DIR if you -# need a separate cache per branch (useful when bisecting feature flags). +# Documented here so contributors can find the levers without grepping CI. \ No newline at end of file diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml index 4e6264c2..6fbfc9b8 100644 --- a/.github/workflows/chatops.yml +++ b/.github/workflows/chatops.yml @@ -4,10 +4,7 @@ on: issue_comment: types: [created] -permissions: - contents: write - pull-requests: write - issues: write +permissions: {} jobs: parse: @@ -16,6 +13,9 @@ jobs: github.event.issue.pull_request && startsWith(github.event.comment.body, '/') runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: read outputs: cmd: ${{ steps.parse.outputs.cmd }} matrix: ${{ steps.parse.outputs.matrix }} @@ -25,10 +25,23 @@ jobs: - name: Gate on author association env: ASSOC: ${{ github.event.comment.author_association }} + BODY: ${{ github.event.comment.body }} run: | - case "$ASSOC" in - OWNER|MEMBER|COLLABORATOR) ;; - *) echo "::error::author_association=$ASSOC not allowed"; exit 1 ;; + first_line=$(printf '%s' "$BODY" | head -n1) + cmd=$(printf '%s' "$first_line" | awk '{print $1}') + case "$cmd" in + /squash-merge) + case "$ASSOC" in + OWNER|MEMBER) ;; + *) echo "::error::author_association=$ASSOC not allowed for /squash-merge"; exit 1 ;; + esac + ;; + *) + case "$ASSOC" in + OWNER|MEMBER|COLLABORATOR) ;; + *) echo "::error::author_association=$ASSOC not allowed"; exit 1 ;; + esac + ;; esac - name: Parse @@ -66,7 +79,7 @@ jobs: - name: Resolve PR head id: pr - uses: 
actions/github-script@v7 + uses: actions/github-script@v7.0.1 with: script: | const pr = await github.rest.pulls.get({ @@ -78,7 +91,7 @@ jobs: core.setOutput('ref', pr.data.head.ref); - name: React to comment - uses: actions/github-script@v7 + uses: actions/github-script@v7.0.1 with: script: | await github.rest.reactions.createForIssueComment({ @@ -107,6 +120,10 @@ jobs: if: needs.parse.outputs.cmd == '/squash-merge' runs-on: ubuntu-latest timeout-minutes: 5 + permissions: + contents: write + pull-requests: write + issues: write steps: - name: Enable auto-merge env: @@ -117,7 +134,7 @@ jobs: - name: Comment if: always() - uses: actions/github-script@v7 + uses: actions/github-script@v7.0.1 with: script: | const result = '${{ job.status }}'; @@ -137,8 +154,10 @@ jobs: needs: [parse, test-template] if: always() && needs.parse.result == 'success' && needs.parse.outputs.cmd == '/test-template' runs-on: ubuntu-latest + permissions: + issues: write steps: - - uses: actions/github-script@v7 + - uses: actions/github-script@v7.0.1 with: script: | const result = '${{ needs.test-template.result }}'; @@ -150,3 +169,22 @@ jobs: issue_number: context.issue.number, body: `${emoji} \`/test-template\` ${result} — [run](${runUrl})`, }); + + report-error: + name: Report parse error + needs: parse + if: always() && needs.parse.result != 'success' + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: actions/github-script@v7.0.1 + with: + script: | + const runUrl = `${context.serverUrl}/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}`; + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: `❌ command rejected (gate or parse failed) — [run](${runUrl})`, + }); diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 779563ee..3c5eaa4f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,14 +36,15 @@ jobs: runs-on: 
ubuntu-latest timeout-minutes: 25 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci cache-on-failure: true + save-if: ${{ github.ref == 'refs/heads/main' }} - name: fmt run: cargo fmt --all --check @@ -64,9 +65,9 @@ jobs: matrix: preset: [worker, api, minimal] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci save-if: "false" @@ -79,12 +80,13 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 20 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@1.92 - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: msrv cache-on-failure: true + save-if: ${{ github.ref == 'refs/heads/main' }} - run: scripts/ci/create-env-stubs.sh - run: cargo check --workspace --all-features @@ -93,16 +95,19 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 10 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci save-if: "false" - - uses: taiki-e/install-action@v2 + - uses: taiki-e/install-action@v2.49.10 with: - tool: cargo-deny,cargo-audit + # Pin exact versions: an unpinned cargo-deny drifted to a release that + # rejects deny.toml's `unmaintained = "workspace"` key, failing the job + # while local (0.19.5) passed. These are the versions verified locally. 
+ tool: cargo-deny@0.19.5,cargo-audit@0.22.1 - name: cargo deny run: cargo deny check @@ -129,14 +134,14 @@ jobs: --health-timeout 5s --health-retries 10 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci save-if: "false" - name: Install sqlx-cli - run: cargo install sqlx-cli --no-default-features --features rustls,postgres + run: cargo install sqlx-cli --locked --version 0.8.6 --no-default-features --features rustls,postgres - name: Apply system migrations run: | for f in crates/forge-runtime/migrations/system/v*.sql; do @@ -174,13 +179,13 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 20 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci save-if: "false" - - uses: oven-sh/setup-bun@v2 + - uses: oven-sh/setup-bun@v2.0.2 - run: scripts/ci/create-env-stubs.sh - name: Typecheck generated frontend bindings run: scripts/ci/typecheck-codegen.sh @@ -191,7 +196,7 @@ jobs: name: Workspace integration needs: [validate, guardrails] runs-on: ubuntu-latest - timeout-minutes: 20 + timeout-minutes: 30 services: postgres: image: postgres:18 @@ -209,9 +214,9 @@ jobs: env: TEST_DATABASE_URL: postgres://postgres:postgres@localhost:5432/forge_ci steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: ci save-if: "false" @@ -219,6 +224,15 @@ jobs: - run: cargo test -p todo --features testcontainers - run: cargo test -p todo-dioxus --features testcontainers - run: cargo test -p forge-harness --features testcontainers + # forge-runtime's own integration suite (jobs queue, change_log, leader + # election, migration runner, signals, KV) — previously never 
run in CI, + # so it had rotted. Two non-obvious requirements: + # * `full` — every subsystem is #[cfg(feature)]-gated; bare + # `testcontainers` compiles and runs ZERO subsystem tests (silent no-op). + # * `--test-threads=1` — these tests exercise PG instance-global state + # (advisory locks, pg_terminate_backend, leader election) that cannot + # run concurrently against one shared database without interfering. + - run: cargo test -p forge-runtime --features "full,testcontainers" -- --test-threads=1 pr-smoke: name: PR Smoke diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index da8e86d0..ebb6b840 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -23,11 +23,11 @@ jobs: run: working-directory: docs steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - - uses: oven-sh/setup-bun@v2 + - uses: oven-sh/setup-bun@v2.0.2 with: - bun-version: latest + bun-version: 1.1.34 - name: Install dependencies run: bun install --frozen-lockfile @@ -36,7 +36,7 @@ jobs: run: bun run build - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v3.0.1 with: path: docs/build @@ -49,4 +49,4 @@ jobs: steps: - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 + uses: actions/deploy-pages@v4.0.5 diff --git a/.github/workflows/docker-otel-lgtm.yml b/.github/workflows/docker-otel-lgtm.yml index 352978bc..bd9fbdb5 100644 --- a/.github/workflows/docker-otel-lgtm.yml +++ b/.github/workflows/docker-otel-lgtm.yml @@ -19,7 +19,7 @@ jobs: name: Build & Push runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - name: Extract otel-lgtm version from Dockerfile id: version @@ -29,7 +29,7 @@ jobs: echo "otel-lgtm version: $VERSION" - name: Log in to GHCR - uses: docker/login-action@v3 + uses: docker/login-action@v3.3.0 with: registry: ghcr.io username: ${{ github.actor }} diff --git 
a/.github/workflows/release.yml b/.github/workflows/release.yml index 01f62dad..73abe7c2 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -29,7 +29,7 @@ jobs: release_notes: ${{ steps.parse.outputs.release_notes }} is_prerelease: ${{ steps.parse.outputs.is_prerelease }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - id: parse run: bash scripts/ci/parse-changelog.sh @@ -37,19 +37,25 @@ jobs: name: Validate runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable with: components: rustfmt, clippy - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: workspace + save-if: ${{ github.ref == 'refs/heads/main' }} - run: scripts/ci/create-env-stubs.sh + - uses: taiki-e/install-action@v2.49.10 + with: + tool: cargo-deny,cargo-audit + - run: cargo deny check + - run: cargo audit --deny warnings - run: cargo fmt --all --check - run: cargo clippy --all-targets --all-features --workspace -- -D warnings - run: cargo test --workspace - run: cargo build -p forgex - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v4.4.3 with: name: forge-cli path: target/debug/forge @@ -70,20 +76,21 @@ jobs: - with-dioxus/demo - with-dioxus/realtime-todo-list steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: template-build workspaces: /tmp/test-project -> target - - uses: oven-sh/setup-bun@v2 - - uses: actions/download-artifact@v4 + save-if: ${{ github.ref == 'refs/heads/main' }} + - uses: oven-sh/setup-bun@v2.0.2 + - uses: actions/download-artifact@v4.1.8 with: name: forge-cli path: /tmp/forge-bin - run: chmod +x /tmp/forge-bin/forge - if: startsWith(matrix.template, 'with-dioxus/') - uses: cargo-bins/cargo-binstall@main + uses: cargo-bins/cargo-binstall@v1.10.20 - 
if: startsWith(matrix.template, 'with-dioxus/') run: cargo binstall dioxus-cli@0.7.5 --no-confirm - run: scripts/ci/test-template.sh "${{ matrix.template }}" /tmp/forge-bin/forge "${{ github.workspace }}" @@ -91,7 +98,7 @@ jobs: run: echo "ARTIFACT_NAME=playwright-${TEMPLATE//\//-}" >> "$GITHUB_ENV" env: TEMPLATE: ${{ matrix.template }} - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v4.4.3 if: failure() with: name: ${{ env.ARTIFACT_NAME }} @@ -105,14 +112,15 @@ jobs: outputs: commit_sha: ${{ steps.commit.outputs.sha }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: token: ${{ secrets.GITHUB_TOKEN }} - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: workspace - - run: cargo install cargo-edit + save-if: ${{ github.ref == 'refs/heads/main' }} + - run: cargo install cargo-edit --locked --version 0.13.6 - run: scripts/ci/bump-versions.sh "${{ needs.parse-changelog.outputs.version }}" - id: commit run: | @@ -128,6 +136,14 @@ jobs: echo "sha=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT else git commit -m "Bump version to ${{ needs.parse-changelog.outputs.version }}" + # Pre-flight: if branch protection rejects direct pushes, surface + # that as a loud error instead of letting `git push` fail with a + # generic "remote rejected" line buried in the log. + REPO="${GITHUB_REPOSITORY:-${{ github.repository }}}" + if gh api "repos/$REPO/branches/main/protection" >/dev/null 2>&1; then + echo "::error::main has branch protection; the release pipeline must push the version bump via a PR or have the bot enrolled in the protection bypass list." 
>&2 + exit 1 + fi git push origin main echo "sha=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT fi @@ -146,8 +162,10 @@ jobs: os: macos-15 - target: aarch64-apple-darwin os: macos-14 + - target: aarch64-unknown-linux-gnu + os: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: ref: ${{ needs.bump-versions.outputs.commit_sha }} - if: runner.os == 'Linux' @@ -155,12 +173,17 @@ jobs: - uses: dtolnay/rust-toolchain@stable with: targets: ${{ matrix.target }} - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: - key: ${{ matrix.target }} + shared-key: release-${{ matrix.target }} + save-if: ${{ startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main' }} + - if: matrix.target == 'aarch64-unknown-linux-gnu' + run: sudo apt-get update && sudo apt-get install -y gcc-aarch64-linux-gnu - run: cargo build --release --target ${{ matrix.target }} -p forgex + env: + CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER: aarch64-linux-gnu-gcc - run: cd target/${{ matrix.target }}/release && tar czvf ../../../${{ env.BINARY_NAME }}-${{ matrix.target }}.tar.gz ${{ env.BINARY_NAME }} - - uses: actions/upload-artifact@v4 + - uses: actions/upload-artifact@v4.4.3 with: name: ${{ env.BINARY_NAME }}-${{ matrix.target }} path: ${{ env.BINARY_NAME }}-${{ matrix.target }}.tar.gz @@ -171,7 +194,7 @@ jobs: needs: [parse-changelog, bump-versions, build] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: ref: ${{ needs.bump-versions.outputs.commit_sha }} - run: | @@ -184,11 +207,11 @@ jobs: git tag -a "$TAG" -m "Release $TAG" git push origin "$TAG" fi - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v4.1.8 with: path: artifacts - run: cd artifacts && for dir in */; do cd "$dir" && for f in *; do sha256sum "$f" > "$f.sha256"; done && cd ..; done - - uses: softprops/action-gh-release@v2 + - uses: softprops/action-gh-release@v2.2.1 with: tag_name: v${{ 
needs.parse-changelog.outputs.version }} name: ${{ needs.parse-changelog.outputs.version }} - ${{ needs.parse-changelog.outputs.release_date }} @@ -203,17 +226,20 @@ jobs: needs: [parse-changelog, bump-versions, release] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: ref: ${{ needs.bump-versions.outputs.commit_sha }} - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 + with: + shared-key: workspace + save-if: ${{ github.ref == 'refs/heads/main' }} - env: - TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} run: | publish_crate() { local output - output=$(cargo publish "$@" --token $TOKEN 2>&1) || { + output=$(cargo publish "$@" 2>&1) || { if echo "$output" | grep -Eq "already uploaded|already exists on crates.io"; then echo "$output" | tail -3 echo "(treated as success: version already on crates.io)" @@ -229,6 +255,13 @@ jobs: for crate_dir in crates/forge-core crates/forge-runtime crates/forge; do cp -r .sqlx "$crate_dir/.sqlx" done + # Guard --allow-dirty: only .sqlx/ paths may appear in the working tree + dirty=$(git status --porcelain | awk '{print $2}' | grep -v '\.sqlx/' || true) + if [ -n "$dirty" ]; then + echo "::error::working tree has unexpected dirty paths beyond .sqlx/:" + echo "$dirty" + exit 1 + fi wait_for_crate() { local crate=$1 version=$2 elapsed=0 @@ -260,10 +293,10 @@ jobs: needs: [parse-changelog, bump-versions, release] runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: ref: ${{ needs.bump-versions.outputs.commit_sha }} - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v4.1.0 with: node-version: "22" registry-url: https://registry.npmjs.org diff --git a/.github/workflows/template-smoke.yml b/.github/workflows/template-smoke.yml index f351617a..606e040d 100644 --- a/.github/workflows/template-smoke.yml +++ 
b/.github/workflows/template-smoke.yml @@ -36,21 +36,22 @@ jobs: runs-on: ubuntu-latest timeout-minutes: ${{ inputs.timeout-minutes }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v4.2.2 with: ref: ${{ inputs.ref }} - uses: dtolnay/rust-toolchain@stable - - uses: Swatinem/rust-cache@v2 + - uses: Swatinem/rust-cache@v2.7.7 with: shared-key: template-build workspaces: . -> target - - uses: oven-sh/setup-bun@v2 + save-if: ${{ (inputs.ref == '' && github.ref == 'refs/heads/main') }} + - uses: oven-sh/setup-bun@v2.0.2 - name: Build forge CLI run: cargo build -p forgex - if: startsWith(inputs.template, 'with-dioxus/') - uses: cargo-bins/cargo-binstall@main + uses: cargo-bins/cargo-binstall@v1.10.20 - if: startsWith(inputs.template, 'with-dioxus/') run: cargo binstall dioxus-cli@0.7.5 --no-confirm @@ -59,7 +60,7 @@ jobs: - if: failure() run: echo "ARTIFACT_SLUG=$(echo '${{ inputs.template }}' | tr '/' '-')" >> "$GITHUB_ENV" - if: failure() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v4.4.3 with: name: playwright-${{ env.ARTIFACT_SLUG }} path: /tmp/forge-test-artifacts/** diff --git a/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json b/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json new file mode 100644 index 00000000..fd027d65 --- /dev/null +++ b/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM todos WHERE user_id = $1 ORDER BY created_at DESC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "title", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "completed", + "type_info": "Bool" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + 
}, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26" +} diff --git a/.sqlx/query-150265ae3b6ca49ac161e6270ca2f65d8c22b78be7936f621513e6f350629576.json b/.sqlx/query-150265ae3b6ca49ac161e6270ca2f65d8c22b78be7936f621513e6f350629576.json deleted file mode 100644 index 23834a85..00000000 --- a/.sqlx/query-150265ae3b6ca49ac161e6270ca2f65d8c22b78be7936f621513e6f350629576.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status)\n VALUES ($1, $2, $3, $4)\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Varchar", - "Varchar", - "Varchar", - "Varchar" - ] - }, - "nullable": [] - }, - "hash": "150265ae3b6ca49ac161e6270ca2f65d8c22b78be7936f621513e6f350629576" -} diff --git a/.sqlx/query-1781c4ea22f4341bc53a015f3a4e0a9f4d31ef6b6b3b28d3d14249215f6b2a47.json b/.sqlx/query-1781c4ea22f4341bc53a015f3a4e0a9f4d31ef6b6b3b28d3d14249215f6b2a47.json deleted file mode 100644 index 8cb406c8..00000000 --- a/.sqlx/query-1781c4ea22f4341bc53a015f3a4e0a9f4d31ef6b6b3b28d3d14249215f6b2a47.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT workflow_signature FROM forge_workflow_definitions\n WHERE workflow_name = $1 AND workflow_version = $2\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "workflow_signature", - "type_info": "Varchar" - } - ], - "parameters": { - "Left": [ - "Text", - "Text" - ] - }, - "nullable": [ - false - ] - }, - "hash": "1781c4ea22f4341bc53a015f3a4e0a9f4d31ef6b6b3b28d3d14249215f6b2a47" -} diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json b/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json similarity index 50% rename from 
examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json rename to .sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json index fb994d1e..5b63fbe4 100644 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json +++ b/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json @@ -1,14 +1,15 @@ { "db_name": "PostgreSQL", - "query": "DELETE FROM todos WHERE id = $1", + "query": "DELETE FROM todos WHERE id = $1 AND user_id = $2", "describe": { "columns": [], "parameters": { "Left": [ + "Uuid", "Uuid" ] }, "nullable": [] }, - "hash": "183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19" + "hash": "2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e" } diff --git a/.sqlx/query-3951f9f43cc083105bec6cd76e07b242d762ad2651d81375833dc7ffba3a6097.json b/.sqlx/query-3951f9f43cc083105bec6cd76e07b242d762ad2651d81375833dc7ffba3a6097.json new file mode 100644 index 00000000..a23707e6 --- /dev/null +++ b/.sqlx/query-3951f9f43cc083105bec6cd76e07b242d762ad2651d81375833dc7ffba3a6097.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO users (id, email, name, password_hash) VALUES ($1, $2, $3, $4)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Varchar", + "Varchar", + "Text" + ] + }, + "nullable": [] + }, + "hash": "3951f9f43cc083105bec6cd76e07b242d762ad2651d81375833dc7ffba3a6097" +} diff --git a/.sqlx/query-493f292d73aecf0bfb68ddd24a9334646c3690e1b30807346d67d9d7b3d21ccb.json b/.sqlx/query-493f292d73aecf0bfb68ddd24a9334646c3690e1b30807346d67d9d7b3d21ccb.json deleted file mode 100644 index f74d31c7..00000000 --- a/.sqlx/query-493f292d73aecf0bfb68ddd24a9334646c3690e1b30807346d67d9d7b3d21ccb.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE forge_workflow_definitions SET status = $3 WHERE 
workflow_name = $1 AND workflow_version = $2", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Text", - "Text", - "Varchar" - ] - }, - "nullable": [] - }, - "hash": "493f292d73aecf0bfb68ddd24a9334646c3690e1b30807346d67d9d7b3d21ccb" -} diff --git a/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json b/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json new file mode 100644 index 00000000..e2ac3226 --- /dev/null +++ b/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json @@ -0,0 +1,57 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO users (id, email, name, password_hash, created_at, updated_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n RETURNING id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Varchar", + "Varchar", + "Text", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c" +} diff --git a/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json b/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json new file mode 100644 index 00000000..3e00e871 --- /dev/null +++ b/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, 
password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE email = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3" +} diff --git a/.sqlx/query-8a6d744259a838bf1967d84007dd830892161b34d71fec6a4edac26c370c0b25.json b/.sqlx/query-8a6d744259a838bf1967d84007dd830892161b34d71fec6a4edac26c370c0b25.json deleted file mode 100644 index 06fd6635..00000000 --- a/.sqlx/query-8a6d744259a838bf1967d84007dd830892161b34d71fec6a4edac26c370c0b25.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE forge_workflow_runs\n SET cancel_requested_at = NOW(),\n cancel_reason = $2\n WHERE id = $1\n AND status IN ('pending', 'running', 'sleeping', 'waiting')\n AND cancel_requested_at IS NULL\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Text" - ] - }, - "nullable": [] - }, - "hash": "8a6d744259a838bf1967d84007dd830892161b34d71fec6a4edac26c370c0b25" -} diff --git a/.sqlx/query-95d674ba7a93b42267ac370fd1ef7bd31bf977f327795fba1a1145711411cfa2.json b/.sqlx/query-95d674ba7a93b42267ac370fd1ef7bd31bf977f327795fba1a1145711411cfa2.json deleted file mode 100644 index 16573921..00000000 --- a/.sqlx/query-95d674ba7a93b42267ac370fd1ef7bd31bf977f327795fba1a1145711411cfa2.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT id FROM forge_jobs\n WHERE idempotency_key = 
$1\n AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled')\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Uuid" - } - ], - "parameters": { - "Left": [ - "Text" - ] - }, - "nullable": [ - false - ] - }, - "hash": "95d674ba7a93b42267ac370fd1ef7bd31bf977f327795fba1a1145711411cfa2" -} diff --git a/.sqlx/query-9a81d60e030f76bb21d7de9679748a2e4f8543cd5bdc163fc9998a0dbbc18dd8.json b/.sqlx/query-9a81d60e030f76bb21d7de9679748a2e4f8543cd5bdc163fc9998a0dbbc18dd8.json new file mode 100644 index 00000000..8cd83f1b --- /dev/null +++ b/.sqlx/query-9a81d60e030f76bb21d7de9679748a2e4f8543cd5bdc163fc9998a0dbbc18dd8.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE forge_cron_runs\n SET status = 'completed', completed_at = NOW(), error = NULL\n WHERE id = $1 AND status = 'running'\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "9a81d60e030f76bb21d7de9679748a2e4f8543cd5bdc163fc9998a0dbbc18dd8" +} diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json b/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json similarity index 68% rename from examples/with-dioxus/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json rename to .sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json index f5b84d9f..6ed720db 100644 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json +++ b/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO todos (title) VALUES ($1) RETURNING *", + "query": "INSERT INTO todos (user_id, title) VALUES ($1, $2) RETURNING *", "describe": { "columns": [ { @@ -10,22 +10,28 @@ }, { 
"ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } ], "parameters": { "Left": [ + "Uuid", "Text" ] }, @@ -33,8 +39,9 @@ false, false, false, + false, false ] }, - "hash": "289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39" + "hash": "cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa" } diff --git a/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json b/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json similarity index 74% rename from .sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json rename to .sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json index 1ebed98d..37c21eb7 100644 --- a/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json +++ b/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = COALESCE($2, completed)\n WHERE id = $3\n RETURNING *", + "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = COALESCE($2, completed)\n WHERE id = $3 AND user_id = $4\n RETURNING *", "describe": { "columns": [ { @@ -10,16 +10,21 @@ }, { "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } @@ -28,6 +33,7 @@ "Left": [ "Text", "Bool", + "Uuid", "Uuid" ] }, @@ -35,8 +41,9 @@ false, false, false, + false, false ] }, - "hash": 
"c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f" + "hash": "d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2" } diff --git a/.sqlx/query-e5ef9e9e6b5d1a2b48a29fc639bec8d08e529c732df88396e7fd960c5d45c99b.json b/.sqlx/query-e5ef9e9e6b5d1a2b48a29fc639bec8d08e529c732df88396e7fd960c5d45c99b.json deleted file mode 100644 index fdfa03b6..00000000 --- a/.sqlx/query-e5ef9e9e6b5d1a2b48a29fc639bec8d08e529c732df88396e7fd960c5d45c99b.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE forge_workflow_runs SET status = 'running' WHERE id = $1 AND status IN ('pending', 'sleeping', 'waiting', 'running')", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "e5ef9e9e6b5d1a2b48a29fc639bec8d08e529c732df88396e7fd960c5d45c99b" -} diff --git a/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json b/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json new file mode 100644 index 00000000..a22ee7e8 --- /dev/null +++ b/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": 
"e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22" +} diff --git a/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json b/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json deleted file mode 100644 index 71889f63..00000000 --- a/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT * FROM todos ORDER BY created_at DESC", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "title", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "completed", - "type_info": "Bool" - }, - { - "ordinal": 3, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763" -} diff --git a/Cargo.lock b/Cargo.lock index 8a8442e8..38fe458a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1300,11 +1300,14 @@ dependencies = [ "dotenvy", "forgex", "futures-util", + "hex", + "hmac", "password-hash", "reqwest", "rust-embed", "serde", "serde_json", + "sha2", "sqlx", "tokio", "tokio-tungstenite 0.26.2", @@ -1356,8 +1359,10 @@ name = "forge-macros" version = "0.10.1" dependencies = [ "blake3", + "chrono-tz", "cron", "darling 0.20.11", + "proc-macro-crate", "proc-macro2", "quote", "sqlparser", @@ -1427,11 +1432,14 @@ dependencies = [ "dotenvy", "forgex", "futures-util", + "hex", + "hmac", "password-hash", "reqwest", "rust-embed", "serde", "serde_json", + "sha2", "sqlx", "tokio", "tokio-tungstenite 0.26.2", @@ -2924,6 +2932,15 @@ dependencies = [ "elliptic-curve", ] +[[package]] +name = "proc-macro-crate" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" 
+dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -4353,9 +4370,11 @@ dependencies = [ name = "todo" version = "0.10.1" dependencies = [ + "argon2", "chrono", "dotenvy", "forgex", + "password-hash", "rust-embed", "serde", "sqlx", @@ -4368,9 +4387,11 @@ dependencies = [ name = "todo-dioxus" version = "0.10.1" dependencies = [ + "argon2", "chrono", "dotenvy", "forgex", + "password-hash", "rust-embed", "serde", "sqlx", @@ -4528,6 +4549,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "toml_edit" +version = "0.25.12+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2153edc6955a6c354fad8f5efd38b6a8769bdccf9fe50f8e1329f81b0baa5d7" +dependencies = [ + "indexmap 2.14.0", + "toml_datetime 1.1.1+spec-1.1.0", + "toml_parser", + "winnow 1.0.2", +] + [[package]] name = "toml_parser" version = "1.1.2+spec-1.1.0" @@ -5518,6 +5551,9 @@ name = "winnow" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +dependencies = [ + "memchr", +] [[package]] name = "winsafe" diff --git a/Cargo.toml b/Cargo.toml index cb4d378a..72b49368 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,7 @@ opentelemetry-appender-tracing = "0.27" syn = { version = "2.0", features = ["full", "visit"] } darling = { version = "0.20" } quote = "1.0" +proc-macro-crate = "3" inventory = "0.3" sha2 = "0.10" blake3 = "1" diff --git a/LICENSE b/LICENSE index be5edf91..f1d1f943 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2025 Isala Piyarisi +Copyright (c) 2025-2026 Isala Piyarisi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/crates/forge-codegen/src/binding.rs b/crates/forge-codegen/src/binding.rs index 1aabafbc..3789220e 100644 --- 
a/crates/forge-codegen/src/binding.rs +++ b/crates/forge-codegen/src/binding.rs @@ -123,7 +123,10 @@ fn build_binding(func: FunctionDef, tables: &[TableDef]) -> FunctionBinding { /// We require BOTH a naming convention match AND existence in the registry. /// This prevents false positives on types like "InputHandler" or "ArgumentParser". fn is_custom_args_type(rust_type: &RustType, tables: &[TableDef]) -> bool { + // Unwrap `Option`/`Vec` wrappers so `Vec` and `Option` + // are recognised as custom-args bindings. match rust_type { + RustType::Option(inner) | RustType::Vec(inner) => is_custom_args_type(inner, tables), RustType::Custom(name) => { (name.ends_with("Args") || name.ends_with("Input")) && tables.iter().any(|t| t.struct_name == *name) diff --git a/crates/forge-codegen/src/dioxus/mod.rs b/crates/forge-codegen/src/dioxus/mod.rs index aef602d9..46a4cd5c 100644 --- a/crates/forge-codegen/src/dioxus/mod.rs +++ b/crates/forge-codegen/src/dioxus/mod.rs @@ -35,6 +35,12 @@ impl DioxusGenerator { } fn mod_content() -> &'static str { + // Framework re-exports are explicit. The api/types globs are kept for + // downstream ergonomics (users reach `get_user(…)` directly), but if a + // user names a type collision-prone (`Mutation`, `QueryState`, …) the + // resolution between `forge_dioxus::Mutation` and the user `Mutation` + // becomes ambiguous. The framework imports are listed explicitly so + // the conflict is at least visible in this file. 
r#"// @generated by FORGE - DO NOT EDIT #![allow(dead_code, unused_imports)] diff --git a/crates/forge-codegen/src/dioxus/types.rs b/crates/forge-codegen/src/dioxus/types.rs index 350593cb..9a080b66 100644 --- a/crates/forge-codegen/src/dioxus/types.rs +++ b/crates/forge-codegen/src/dioxus/types.rs @@ -11,7 +11,7 @@ use crate::emit::{self, contains_json, contains_upload}; pub fn generate(registry: &SchemaRegistry) -> Result { let mut output = String::from( - "// @generated by FORGE - DO NOT EDIT\n\n#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)]\n\n", + "// @generated by FORGE - DO NOT EDIT\n\n#![allow(dead_code, unused_imports, clippy::too_many_arguments)]\n\n", ); output.push_str("use serde::{Deserialize, Serialize};\n"); @@ -58,9 +58,11 @@ pub fn generate(registry: &SchemaRegistry) -> Result { fn render_struct(table: &TableDef) -> String { let has_upload = table.fields.iter().any(|f| contains_upload(&f.rust_type)); - // Upload fields cannot derive PartialEq. + // ForgeUpload doesn't impl PartialEq or Serialize/Deserialize, so upload + // fields are skipped from the wire payload. Mutation routing handles them + // out-of-band via multipart. 
let derives = if has_upload { - "#[derive(Debug, Clone)]\n" + "#[derive(Debug, Clone, Serialize, Deserialize)]\n" } else { "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]\n" }; @@ -69,6 +71,9 @@ fn render_struct(table: &TableDef) -> String { output.push_str(derives); output.push_str(&format!("pub struct {} {{\n", table.struct_name)); for field in &table.fields { + if contains_upload(&field.rust_type) { + output.push_str(" #[serde(skip)]\n"); + } output.push_str(&format!( " pub {}: {},\n", field.name, @@ -104,11 +109,12 @@ fn render_struct_impl(struct_name: &str, fields: &[FieldDef]) -> String { let mut constructor_body = String::new(); for field in &required_fields { - constructor_body.push_str(&format!( - " {}: {},\n", - field.name, - builder::value_expr(&field.name, &field.rust_type) - )); + let value = builder::value_expr(&field.name, &field.rust_type); + if value == field.name { + constructor_body.push_str(&format!(" {},\n", field.name)); + } else { + constructor_body.push_str(&format!(" {}: {},\n", field.name, value)); + } } for field in &optional_fields { constructor_body.push_str(&format!(" {}: None,\n", field.name)); @@ -181,8 +187,9 @@ mod tests { registry.register_table(table); let output = generate(®istry).expect("upload struct generation should succeed"); - assert!(output.contains("#[derive(Debug, Clone)]")); + assert!(output.contains("#[derive(Debug, Clone, Serialize, Deserialize)]")); assert!(!output.contains("PartialEq")); + assert!(output.contains("#[serde(skip)]")); assert!(output.contains("ForgeUpload")); } diff --git a/crates/forge-codegen/src/emit.rs b/crates/forge-codegen/src/emit.rs index 3c8fbe78..3c15cd10 100644 --- a/crates/forge-codegen/src/emit.rs +++ b/crates/forge-codegen/src/emit.rs @@ -160,7 +160,12 @@ fn dioxus_custom(name: &str) -> String { "Uuid" | "uuid::Uuid" => "String".into(), "DateTime" | "NaiveDate" | "NaiveDateTime" | "Instant" | "LocalDate" | "LocalTime" | "Timestamp" => "String".into(), - "i32" | "u32" | 
"usize" | "isize" => "i64".into(), + // Preserve narrow integer widths instead of silently widening to i64. + // Mirrors what the handler actually returns on the wire. + "i32" => "i32".into(), + "u32" => "u32".into(), + "usize" => "usize".into(), + "isize" => "isize".into(), "i64" | "u64" => "i64".into(), "f32" => "f32".into(), "f64" => "f64".into(), @@ -371,11 +376,10 @@ mod tests { #[test] fn dioxus_hashmap() { - // Custom-string primitives go through dioxus_custom which widens - // `i32`/`u32`/etc. to `i64`. The HashMap value follows the same path. + // Narrow integers are preserved; the HashMap value follows the same path. assert_eq!( dioxus_type(&RustType::Custom("HashMap".into())), - "std::collections::HashMap" + "std::collections::HashMap" ); assert_eq!( dioxus_type(&RustType::Custom("HashMap".into())), diff --git a/crates/forge-codegen/src/parser.rs b/crates/forge-codegen/src/parser.rs index 7cafa579..239a0415 100644 --- a/crates/forge-codegen/src/parser.rs +++ b/crates/forge-codegen/src/parser.rs @@ -22,7 +22,6 @@ use forge_core::schema::{ use forge_core::util::to_snake_case; use std::collections::BTreeMap; -use quote::ToTokens; use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType}; use crate::Error; @@ -151,10 +150,11 @@ pub fn find_duplicate_handlers(src_dir: &Path) -> Result Result<(), Error> { match item { syn::Item::Struct(item_struct) => { if has_forge_attr(&item_struct.attrs, "model") { - if let Some(table) = parse_model(&item_struct) { - registry.register_table(table); - } - } else if has_serde_derive(&item_struct.attrs) - && let Some(table) = parse_dto_struct(&item_struct) - { - registry.register_table(table); + registry.register_table(parse_model(&item_struct)?); + } else if has_serde_derive(&item_struct.attrs) { + registry.register_table(parse_dto_struct(&item_struct)?); } } - syn::Item::Enum(item_enum) => { - if (has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs)) - && let Some(enum_def) = 
parse_enum(&item_enum) - { - registry.register_enum(enum_def); - } + syn::Item::Enum(item_enum) + if has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs) => + { + registry.register_enum(parse_enum(&item_enum)?); } syn::Item::Fn(item_fn) => { - if let Some(func) = parse_function(&item_fn) { - registry.register_function(func); + if let Some(func) = parse_function(&item_fn)? { + register_function_checked(registry, func)?; } } _ => {} @@ -203,65 +197,98 @@ fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> { Ok(()) } -/// Check if attributes contain `#[forge::name]` or `#[name]`. +fn register_function_checked(registry: &SchemaRegistry, func: FunctionDef) -> Result<(), Error> { + if let Some(existing) = registry.get_function(&func.name) + && existing.kind != func.kind + { + return Err(Error::Parse { + file: String::new(), + message: format!( + "handler name collision: `{}` is registered as both {:?} and {:?}", + func.name, existing.kind, func.kind + ), + }); + } + registry.register_function(func); + Ok(()) +} + +/// Check if attributes contain `#[name]` or any `#[…::name]` re-import. +/// +/// Matches by last segment so `#[forge::query]`, `#[forge_macros::query]`, +/// and `#[crate::query]` all count. fn has_forge_attr(attrs: &[Attribute], name: &str) -> bool { attrs.iter().any(|attr| { - let path = attr.path(); - path.is_ident(name) - || matches!( - (path.segments.first(), path.segments.get(1), path.segments.get(2)), - (Some(first), Some(second), None) - if first.ident == "forge" && second.ident == name - ) + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == name) }) } -/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or `#[forge::enum_type]`. +/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or any +/// `#[…::forge_enum]` / `#[…::enum_type]` re-import. 
fn has_forge_enum_attr(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| { - let path = attr.path(); - path.is_ident("forge_enum") - || path.is_ident("enum_type") - || matches!( - (path.segments.first(), path.segments.get(1), path.segments.get(2)), - (Some(first), Some(second), None) - if first.ident == "forge" - && (second.ident == "enum_type" || second.ident == "forge_enum") - ) + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == "forge_enum" || seg.ident == "enum_type") }) } +/// True iff a `#[derive(...)]` attribute names `Serialize` or `Deserialize` +/// as a path segment (not as part of a longer identifier like `MySerialize`). fn has_serde_derive(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| { if !attr.path().is_ident("derive") { return false; } - let tokens = attr.meta.to_token_stream().to_string(); - tokens.contains("Serialize") || tokens.contains("Deserialize") + let Meta::List(list) = &attr.meta else { + return false; + }; + let mut found = false; + let _ = list.parse_nested_meta(|meta| { + if let Some(seg) = meta.path.segments.last() + && (seg.ident == "Serialize" || seg.ident == "Deserialize") + { + found = true; + } + Ok(()) + }); + found }) } -fn parse_dto_struct(item: &syn::ItemStruct) -> Option { +fn parse_dto_struct(item: &syn::ItemStruct) -> Result { let struct_name = item.ident.to_string(); + reject_unsupported_struct_serde_attrs(&struct_name, &item.attrs)?; + let mut table = TableDef::new(&struct_name, &struct_name); table.is_dto = true; table.doc = get_doc_comment(&item.attrs); - if let Fields::Named(fields) = &item.fields { - for field in &fields.named { - if let Some(field_name) = &field.ident { - table - .fields - .push(parse_field(field_name.to_string(), &field.ty, &field.attrs)); - } + let Fields::Named(fields) = &item.fields else { + return Err(parse_err(format!( + "DTO struct `{struct_name}` must use named fields; tuple and unit structs are not supported by codegen" + ))); + }; + + for field in 
&fields.named { + if let Some(field_name) = &field.ident + && let Some(parsed) = parse_field(field_name.to_string(), &field.ty, &field.attrs)? + { + table.fields.push(parsed); } } - Some(table) + Ok(table) } -fn parse_model(item: &syn::ItemStruct) -> Option { +fn parse_model(item: &syn::ItemStruct) -> Result { let struct_name = item.ident.to_string(); + reject_unsupported_struct_serde_attrs(&struct_name, &item.attrs)?; + let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| { let snake = to_snake_case(&struct_name); pluralize(&snake) @@ -270,33 +297,139 @@ fn parse_model(item: &syn::ItemStruct) -> Option { let mut table = TableDef::new(&table_name, &struct_name); table.doc = get_doc_comment(&item.attrs); - if let Fields::Named(fields) = &item.fields { - for field in &fields.named { - if let Some(field_name) = &field.ident { - table - .fields - .push(parse_field(field_name.to_string(), &field.ty, &field.attrs)); - } + let Fields::Named(fields) = &item.fields else { + return Err(parse_err(format!( + "model `{struct_name}` must use named fields; tuple and unit structs are not supported by codegen" + ))); + }; + + for field in &fields.named { + if let Some(field_name) = &field.ident + && let Some(parsed) = parse_field(field_name.to_string(), &field.ty, &field.attrs)? + { + table.fields.push(parsed); } } - Some(table) + Ok(table) } -fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef { - let rust_type = type_to_rust_type(ty); - let mut field = FieldDef::new(&name, rust_type); - field.column_name = to_snake_case(&name); +fn parse_field( + name: String, + ty: &syn::Type, + attrs: &[Attribute], +) -> Result, Error> { + let serde = parse_field_serde(&name, attrs)?; + // A serde-skipped field never appears in the JSON wire shape, so it must not + // appear in the generated type. Returning None (rather than erroring) lets a + // struct keep server-only fields like `password_hash` without breaking the + // whole file's bindings. 
+ if serde.skip { + return Ok(None); + } + let rust_type = type_to_rust_type(ty)?; + let final_name = serde.rename.unwrap_or(name.clone()); + let mut field = FieldDef::new(&final_name, rust_type); + field.column_name = to_snake_case(&final_name); field.doc = get_doc_comment(attrs); - field + let _ = sanitize_reserved(&name); // surface reserved-word handling at field site + Ok(Some(field)) +} + +#[derive(Default)] +struct FieldSerdeAttrs { + rename: Option, + /// `#[serde(skip)]` / `skip_serializing` / `skip_deserializing` — the field + /// is absent from the JSON wire shape, so it must be omitted from the + /// generated type rather than failing the whole file. + skip: bool, +} + +/// Parse `#[serde(...)]` directives on a field. Errors on directives codegen +/// cannot honor; honors `rename` and tolerates `default`. +fn parse_field_serde(field_name: &str, attrs: &[Attribute]) -> Result { + let mut out = FieldSerdeAttrs::default(); + for attr in attrs { + if !attr.path().is_ident("serde") { + continue; + } + let Meta::List(list) = &attr.meta else { + continue; + }; + let mut err: Option = None; + let _ = list.parse_nested_meta(|meta| { + let Some(seg) = meta.path.segments.last() else { + return Ok(()); + }; + match seg.ident.to_string().as_str() { + "rename" => { + if let Ok(value) = meta.value() + && let Ok(lit) = value.parse::() + { + out.rename = Some(lit.value()); + } + } + "default" => {} + "skip" | "skip_serializing" | "skip_deserializing" => { + out.skip = true; + } + "flatten" => { + err = Some(parse_err(format!( + "field `{field_name}`: `#[serde(flatten)]` is not supported by codegen" + ))); + } + _ => {} + } + Ok(()) + }); + if let Some(e) = err { + return Err(e); + } + } + Ok(out) +} + +/// Reject struct-level serde directives that codegen cannot faithfully honor. 
+fn reject_unsupported_struct_serde_attrs(name: &str, attrs: &[Attribute]) -> Result<(), Error> { + for attr in attrs { + if !attr.path().is_ident("serde") { + continue; + } + let Meta::List(list) = &attr.meta else { + continue; + }; + let mut err: Option = None; + let _ = list.parse_nested_meta(|meta| { + if let Some(seg) = meta.path.segments.last() + && (seg.ident == "rename_all" || seg.ident == "tag" || seg.ident == "untagged") + { + err = Some(parse_err(format!( + "struct `{name}`: `#[serde({})]` is not supported by codegen", + seg.ident + ))); + } + Ok(()) + }); + if let Some(e) = err { + return Err(e); + } + } + Ok(()) } -fn parse_enum(item: &syn::ItemEnum) -> Option { +fn parse_enum(item: &syn::ItemEnum) -> Result { let enum_name = item.ident.to_string(); let mut enum_def = EnumDef::new(&enum_name); enum_def.doc = get_doc_comment(&item.attrs); for variant in &item.variants { + if !matches!(variant.fields, Fields::Unit) { + return Err(parse_err(format!( + "enum `{}` variant `{}`: non-unit variants are not yet supported by codegen", + enum_name, variant.ident + ))); + } + let variant_name = variant.ident.to_string(); let mut enum_variant = EnumVariant::new(&variant_name); enum_variant.doc = get_doc_comment(&variant.attrs); @@ -311,16 +444,62 @@ fn parse_enum(item: &syn::ItemEnum) -> Option { enum_def.variants.push(enum_variant); } - Some(enum_def) + Ok(enum_def) +} + +fn parse_err(message: String) -> Error { + Error::Parse { + file: String::new(), + message, + } } -fn parse_function(item: &syn::ItemFn) -> Option { - let kind = get_function_kind(&item.attrs)?; +/// TS and Rust reserved words that survive verbatim through codegen. +/// Returning `None` means the name is fine; `Some(_)` returns a sanitized form +/// (currently used only to provoke a parser-level warning at the call site). 
+fn sanitize_reserved(name: &str) -> Option { + const RESERVED: &[&str] = &[ + // TS + "type", + "class", + "interface", + "enum", + "default", + "import", + "export", + "function", + "var", + "let", + "const", + "new", + "delete", + // Rust (subset that round-trips into emitted Rust) + "match", + "mod", + "pub", + "fn", + "impl", + "trait", + "use", + "ref", + "move", + ]; + if RESERVED.contains(&name) { + Some(format!("{name}_")) + } else { + None + } +} + +fn parse_function(item: &syn::ItemFn) -> Result, Error> { + let Some(kind) = get_function_kind(&item.attrs) else { + return Ok(None); + }; let func_name = item.sig.ident.to_string(); let return_type = match &item.sig.output { ReturnType::Default => RustType::Custom("()".to_string()), - ReturnType::Type(_, ty) => extract_result_type(ty), + ReturnType::Type(_, ty) => extract_result_type(ty)?, }; let mut func = FunctionDef::new(&func_name, kind, return_type); @@ -339,13 +518,18 @@ fn parse_function(item: &syn::ItemFn) -> Option { if let Pat::Ident(pat_ident) = &*pat_type.pat { let arg_name = pat_ident.ident.to_string(); - let arg_type = type_to_rust_type(&pat_type.ty); + if sanitize_reserved(&arg_name).is_some() { + return Err(parse_err(format!( + "handler `{func_name}` argument `{arg_name}` is a reserved word in TS/Rust; rename it" + ))); + } + let arg_type = type_to_rust_type(&pat_type.ty)?; func.args.push(FunctionArg::new(arg_name, arg_type)); } } } - Some(func) + Ok(Some(func)) } /// Known Forge context types. Only these are skipped as the first parameter. @@ -362,7 +546,13 @@ const KNOWN_CONTEXT_TYPES: &[&str] = &[ ]; /// Check if a type is a known Forge context type. -/// Walks `&T`/`&mut T` references and checks the final path segment. +/// +/// Walks `&T`/`&mut T` references. 
To avoid false positives on user-defined +/// types named `QueryContext`, requires either a bare path +/// (`QueryContext`) — which is the conventional `use forge::prelude::*` +/// case — or a qualified path beginning with `forge`, `forge_core`, or +/// `crate`. Any other prefix (e.g. `myapp::QueryContext`) is treated as +/// a user type and is NOT stripped from the RPC signature. fn is_context_type(ty: &syn::Type) -> bool { let mut inner = ty; while let syn::Type::Reference(r) = inner { @@ -374,7 +564,16 @@ fn is_context_type(ty: &syn::Type) -> bool { let Some(last) = type_path.path.segments.last() else { return false; }; - KNOWN_CONTEXT_TYPES.contains(&last.ident.to_string().as_str()) + if !KNOWN_CONTEXT_TYPES.contains(&last.ident.to_string().as_str()) { + return false; + } + let segments = &type_path.path.segments; + if segments.len() == 1 { + return true; + } + segments + .first() + .is_some_and(|s| s.ident == "forge" || s.ident == "forge_core" || s.ident == "crate") } fn get_function_kind(attrs: &[Attribute]) -> Option { @@ -403,7 +602,7 @@ fn get_function_kind(attrs: &[Attribute]) -> Option { } /// Extract the inner `T` from `Result`. 
-fn extract_result_type(ty: &syn::Type) -> RustType { +fn extract_result_type(ty: &syn::Type) -> Result { if let syn::Type::Path(type_path) = ty && let Some(seg) = type_path.path.segments.last() && seg.ident == "Result" @@ -416,21 +615,21 @@ fn extract_result_type(ty: &syn::Type) -> RustType { type_to_rust_type(ty) } -fn type_to_rust_type(ty: &syn::Type) -> RustType { +fn type_to_rust_type(ty: &syn::Type) -> Result { match ty { syn::Type::Reference(r) => type_to_rust_type(&r.elem), syn::Type::Path(tp) => path_to_rust_type(tp), - _ => RustType::Custom(quote::quote!(#ty).to_string()), + _ => Ok(RustType::Custom(quote::quote!(#ty).to_string())), } } -fn path_to_rust_type(tp: &syn::TypePath) -> RustType { +fn path_to_rust_type(tp: &syn::TypePath) -> Result { let Some(last) = tp.path.segments.last() else { - return RustType::Custom(quote::quote!(#tp).to_string()); + return Ok(RustType::Custom(quote::quote!(#tp).to_string())); }; let ident = last.ident.to_string(); - match ident.as_str() { + Ok(match ident.as_str() { "String" | "str" => RustType::String, "i32" => RustType::I32, "i64" => RustType::I64, @@ -443,18 +642,45 @@ fn path_to_rust_type(tp: &syn::TypePath) -> RustType { "NaiveTime" => RustType::LocalTime, "Value" => RustType::Json, "Option" => { - let inner = first_generic_arg(last); + let inner = first_generic_arg(last, "Option")?; RustType::Option(Box::new(inner)) } "Vec" => { if is_vec_u8(last) { - return RustType::Bytes; + return Ok(RustType::Bytes); } - let inner = first_generic_arg(last); + let inner = first_generic_arg(last, "Vec")?; RustType::Vec(Box::new(inner)) } + // `HashMap` / `BTreeMap` are preserved as their full + // textual form so the TS/Dioxus emitters can route through their + // `Custom("HashMap<…>")` branches. Bare `HashMap`/`BTreeMap` is a + // parse error to mirror bare `Vec`/`Option`. 
+ "HashMap" | "BTreeMap" => map_to_rust_type(last, &ident)?, _ => RustType::Custom(ident), - } + }) +} + +fn map_to_rust_type(seg: &syn::PathSegment, name: &str) -> Result { + let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { + return Err(parse_err(format!( + "bare `{name}` is not a valid type; expected `{name}`" + ))); + }; + let mut iter = args.args.iter().filter_map(|a| match a { + syn::GenericArgument::Type(t) => Some(t), + _ => None, + }); + let (Some(key), Some(value)) = (iter.next(), iter.next()) else { + return Err(parse_err(format!( + "`{name}` requires two type parameters ``" + ))); + }; + let key_str = quote::quote!(#key).to_string().replace(' ', ""); + let value_str = quote::quote!(#value).to_string().replace(' ', ""); + // Always normalize to `HashMap<…>` so the emitters' existing + // string-prefix branches fire for both `HashMap` and `BTreeMap`. + Ok(RustType::Custom(format!("HashMap<{key_str}, {value_str}>"))) } fn is_vec_u8(seg: &syn::PathSegment) -> bool { @@ -467,13 +693,15 @@ fn is_vec_u8(seg: &syn::PathSegment) -> bool { false } -fn first_generic_arg(seg: &syn::PathSegment) -> RustType { +fn first_generic_arg(seg: &syn::PathSegment, container: &str) -> Result { if let syn::PathArguments::AngleBracketed(args) = &seg.arguments && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { return type_to_rust_type(inner_ty); } - RustType::Custom(seg.ident.to_string()) + Err(parse_err(format!( + "bare `{container}` is not a valid type; expected `{container}`" + ))) } /// Get `#[table(name = "...")]` value from attributes. @@ -573,6 +801,38 @@ mod tests { assert_eq!(table.fields.len(), 4); } + #[test] + fn serde_skip_fields_are_omitted_not_rejected() { + // A field that serde never serializes isn't in the JSON wire shape, so + // it must be dropped from the generated type — and crucially must NOT + // fail the whole file (which would orphan every other type defined in + // it, e.g. 
a UserPublic next to a password-bearing User model). + let source = r#" + #[derive(serde::Serialize, serde::Deserialize)] + struct UserPublic { + id: Uuid, + email: String, + #[serde(skip_serializing)] + password_hash: Option, + #[serde(skip)] + internal: String, + } + "#; + + let registry = SchemaRegistry::new(); + parse_file(source, ®istry).expect("skip fields must not fail the file"); + + let table = registry + .get_table("UserPublic") + .expect("UserPublic should still be registered"); + let names: Vec<&str> = table.fields.iter().map(|f| f.name.as_str()).collect(); + assert_eq!( + names, + vec!["id", "email"], + "serde-skipped fields must be omitted from the generated type", + ); + } + #[test] fn test_parse_enum_source() { let source = r#" @@ -969,7 +1229,7 @@ mod tests { fn parse_type(s: &str) -> RustType { let ty: syn::Type = syn::parse_str(s).expect("valid type"); - type_to_rust_type(&ty) + type_to_rust_type(&ty).expect("valid type maps to RustType") } #[test] diff --git a/crates/forge-codegen/src/typescript/api.rs b/crates/forge-codegen/src/typescript/api.rs index 6048759d..0202064b 100644 --- a/crates/forge-codegen/src/typescript/api.rs +++ b/crates/forge-codegen/src/typescript/api.rs @@ -11,6 +11,7 @@ use crate::binding::{BindingSet, FunctionBinding}; use crate::emit::{self, Position}; pub fn generate(bindings: &BindingSet) -> Result { + check_store_factory_collisions(bindings)?; let mut output = String::from("// @generated by FORGE - DO NOT EDIT\n\n"); let mut type_imports = Vec::new(); @@ -127,6 +128,30 @@ fn gen_subscription(b: &FunctionBinding) -> String { ) } +/// `gen_store_factory` emits `track{Pascal}` for jobs and workflows; if a +/// user query/mutation is named `track_foo`, both factories would collide +/// on the same `trackFoo` identifier. Fail loudly at codegen rather than +/// emitting duplicate `export const` lines. 
+fn check_store_factory_collisions(bindings: &BindingSet) -> Result<(), Error> { + use std::collections::HashSet; + let mut user_names: HashSet = HashSet::new(); + for b in bindings.queries.iter().chain(bindings.mutations.iter()) { + user_names.insert(to_camel_case(&b.name)); + } + for b in bindings.jobs.iter().chain(bindings.workflows.iter()) { + let factory = format!("track{}", to_pascal_case(&b.name)); + if user_names.contains(&factory) { + return Err(Error::Generation(format!( + "store factory name `{factory}` (from {kind:?} `{name}`) collides \ + with a user query/mutation of the same camelCase name; rename one of them", + kind = b.kind, + name = b.name, + ))); + } + } + Ok(()) +} + fn gen_store_factory(b: &FunctionBinding, store_fn: &str) -> String { let factory_name = format!("track{}", to_pascal_case(&b.name)); let output_type = emit::ts_type(&b.return_type, Position::Return); diff --git a/crates/forge-codegen/src/typescript/mod.rs b/crates/forge-codegen/src/typescript/mod.rs index f1cce9b6..575a1cff 100644 --- a/crates/forge-codegen/src/typescript/mod.rs +++ b/crates/forge-codegen/src/typescript/mod.rs @@ -322,17 +322,19 @@ export function getToken(): string | null { } fn generate_index(&self, registry: &SchemaRegistry) -> String { - let has_queries = registry + // Mirror the predicate used by `reactive::generate`: emit the re-export + // whenever the reactive file itself is emitted (queries OR mutations). 
+ let has_reactive = registry .all_functions() .iter() - .any(|f| matches!(f.kind, FunctionKind::Query)); + .any(|f| matches!(f.kind, FunctionKind::Query | FunctionKind::Mutation)); let mut output = String::from("// @generated by FORGE - DO NOT EDIT\n"); output.push_str("export * from './types';\n"); output.push_str("export * from './api';\n"); output.push_str("export * from './stores';\n"); output.push_str("export * from './runes.svelte';\n"); - if has_queries { + if has_reactive { output.push_str("export * from './reactive.svelte';\n"); } if self.options.generate_auth_store { @@ -409,17 +411,32 @@ mod tests { } #[test] - fn generate_index_skips_reactive_when_only_mutations_present() { + fn generate_index_emits_reactive_when_only_mutations_present() { + // Reactive file is emitted for queries OR mutations, so the re-export + // must match. A mutation-only project still publishes `createX$`. let generator = TypeScriptGenerator::new("/tmp/forge"); let registry = SchemaRegistry::new(); registry.register_function(FunctionDef::mutation("create_user", RustType::String)); let index = generator.generate_index(®istry); assert!( - !index.contains("'./reactive.svelte'"), - "no queries => no reactive export" + index.contains("'./reactive.svelte'"), + "mutations must trigger reactive export" ); } + #[test] + fn generate_index_skips_reactive_when_no_queries_or_mutations() { + let generator = TypeScriptGenerator::new("/tmp/forge"); + let registry = SchemaRegistry::new(); + registry.register_function(FunctionDef::new( + "daily_cleanup", + FunctionKind::Cron, + RustType::Custom("()".into()), + )); + let index = generator.generate_index(®istry); + assert!(!index.contains("'./reactive.svelte'")); + } + #[test] fn generate_index_emits_auth_only_when_flag_set() { let registry = SchemaRegistry::new(); diff --git a/crates/forge-codegen/tests/snapshot.rs b/crates/forge-codegen/tests/snapshot.rs index c799411f..b6112281 100644 --- a/crates/forge-codegen/tests/snapshot.rs +++ 
b/crates/forge-codegen/tests/snapshot.rs @@ -78,6 +78,59 @@ fn output_is_deterministic() { assert_eq!(first, second, "codegen output must be deterministic"); } +/// Negative coverage: a `#[forge::model]` on a tuple struct must produce a +/// parser diagnostic with a clear message, not silently emit an empty +/// interface. A regression that drops the named-fields guard would otherwise +/// let downstream tests pass with an empty `Wrapper {}` interface. +#[test] +fn tuple_struct_model_is_rejected_with_clear_diagnostic() { + let src = r#" + #[forge::model] + pub struct Wrapper(pub String); + "#; + let src_dir = TempDir::new().expect("tempdir"); + fs::write(src_dir.path().join("handlers.rs"), src).expect("write fixture"); + + let outcome = parse_project(src_dir.path()).expect("parse_project"); + assert!( + !outcome.parse_failures.is_empty(), + "tuple struct must be rejected by the parser, got no failures" + ); + let (_, msg) = outcome + .parse_failures + .first() + .expect("at least one failure"); + assert!( + msg.contains("named fields"), + "diagnostic must explain the constraint, got: {msg}" + ); +} + +/// Same guard for unit structs marked as DTOs (serde derive). 
+#[test] +fn unit_struct_dto_is_rejected_with_clear_diagnostic() { + let src = r#" + #[derive(serde::Serialize, serde::Deserialize)] + pub struct Marker; + "#; + let src_dir = TempDir::new().expect("tempdir"); + fs::write(src_dir.path().join("handlers.rs"), src).expect("write fixture"); + + let outcome = parse_project(src_dir.path()).expect("parse_project"); + assert!( + !outcome.parse_failures.is_empty(), + "unit struct DTO must be rejected by the parser, got no failures" + ); + let (_, msg) = outcome + .parse_failures + .first() + .expect("at least one failure"); + assert!( + msg.contains("named fields"), + "diagnostic must explain the constraint, got: {msg}" + ); +} + fn run_fixture(source: &str, generate_auth: bool) -> String { let src_dir = TempDir::new().expect("tempdir"); fs::write(src_dir.path().join("handlers.rs"), source).expect("write fixture"); diff --git a/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap b/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap index ac244c3d..b3f82529 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 33 expression: "run_fixture(include_str!(\"fixtures/custom_args.rs.txt\"), false)" --- @@ -261,7 +260,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap b/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap index 3aed3cf3..f386ed64 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap +++ 
b/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 57 expression: "run_fixture(include_str!(\"fixtures/full_app.rs.txt\"), false)" --- @@ -325,7 +324,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -336,7 +335,7 @@ pub struct CleanupOutput { impl CleanupOutput { pub fn new(deleted_count: i64) -> Self { Self { - deleted_count: deleted_count, + deleted_count, } } } @@ -349,7 +348,7 @@ pub struct CleanupRequest { impl CleanupRequest { pub fn new(older_than_days: i32) -> Self { Self { - older_than_days: older_than_days, + older_than_days, } } } @@ -364,7 +363,7 @@ impl CreateTodoInput { pub fn new(title: impl Into, status: TodoStatus) -> Self { Self { title: title.into(), - status: status, + status, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap b/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap index a8dc95b3..d06848a6 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 65 expression: "run_fixture(include_str!(\"fixtures/full_app.rs.txt\"), true)" --- @@ -483,7 +482,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -494,7 +493,7 @@ pub 
struct CleanupOutput { impl CleanupOutput { pub fn new(deleted_count: i64) -> Self { Self { - deleted_count: deleted_count, + deleted_count, } } } @@ -507,7 +506,7 @@ pub struct CleanupRequest { impl CleanupRequest { pub fn new(older_than_days: i32) -> Self { Self { - older_than_days: older_than_days, + older_than_days, } } } @@ -522,7 +521,7 @@ impl CreateTodoInput { pub fn new(title: impl Into, status: TodoStatus) -> Self { Self { title: title.into(), - status: status, + status, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap b/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap index c364c76c..72ff4cd8 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 41 expression: "run_fixture(include_str!(\"fixtures/jobs_and_workflows.rs.txt\"), false)" --- @@ -215,7 +214,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -228,7 +227,7 @@ impl ExportRequest { pub fn new(format: impl Into, include_archived: bool) -> Self { Self { format: format.into(), - include_archived: include_archived, + include_archived, } } } @@ -243,7 +242,7 @@ impl ExportResult { pub fn new(url: impl Into, row_count: i64) -> Self { Self { url: url.into(), - row_count: row_count, + row_count, } } } @@ -271,7 +270,7 @@ pub struct VerifyOutput { impl VerifyOutput { pub fn new(verified: bool) -> Self { Self { - verified: verified, + verified, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap 
b/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap index 8aaa023a..791ce6c8 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 25 expression: "run_fixture(include_str!(\"fixtures/models_and_enums.rs.txt\"), false)" --- @@ -242,7 +241,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -266,7 +265,7 @@ impl UserSummary { Self { id: id.into(), email: email.into(), - role: role, + role, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap b/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap index 961b38e4..b9e75109 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 17 expression: "run_fixture(include_str!(\"fixtures/primitives.rs.txt\"), false)" --- @@ -332,6 +331,6 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; diff --git a/crates/forge-codegen/tests/snapshots/snapshot__upload.snap b/crates/forge-codegen/tests/snapshots/snapshot__upload.snap index c66c8cda..fa354dc5 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__upload.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__upload.snap @@ -1,6 +1,5 @@ --- source: 
crates/forge-codegen/tests/snapshot.rs -assertion_line: 49 expression: "run_fixture(include_str!(\"fixtures/upload.rs.txt\"), false)" --- @@ -24,6 +23,7 @@ export * from './types'; export * from './api'; export * from './stores'; export * from './runes.svelte'; +export * from './reactive.svelte'; export { ForgeClient, ForgeClientError, createForgeClient, ForgeProvider } from '@forge-rs/svelte'; === ts/reactive.svelte.ts === @@ -215,14 +215,15 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; use forge_dioxus::ForgeUpload; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct AvatarInput { pub name: String, + #[serde(skip)] pub file: ForgeUpload, } @@ -230,7 +231,7 @@ impl AvatarInput { pub fn new(name: impl Into, file: ForgeUpload) -> Self { Self { name: name.into(), - file: file, + file, } } } diff --git a/crates/forge-core/Cargo.toml b/crates/forge-core/Cargo.toml index 8bd28479..b566528b 100644 --- a/crates/forge-core/Cargo.toml +++ b/crates/forge-core/Cargo.toml @@ -34,6 +34,10 @@ testcontainers-modules = { workspace = true, optional = true } [features] testcontainers = ["dep:testcontainers", "dep:testcontainers-modules"] +# Unsafe escape hatches that bypass framework guard rails (circuit breaker, +# host blocklists, etc). Opt-in only — keep this off unless a specific +# integration genuinely needs raw access. 
+escape-hatches = [] [dev-dependencies] tokio-test = { workspace = true } diff --git a/crates/forge-core/src/config/auth.rs b/crates/forge-core/src/config/auth.rs index 3298b2a2..03214494 100644 --- a/crates/forge-core/src/config/auth.rs +++ b/crates/forge-core/src/config/auth.rs @@ -30,7 +30,7 @@ pub enum JwtAlgorithm { /// Rotate by adding the outgoing secret here with `valid_until` set one /// access-token TTL into the future, swap `jwt_secret` to the new value, /// then remove the entry once the window closes. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct LegacySecret { /// HMAC secret bytes (treated as opaque; min length is not re-enforced /// here — the active `jwt_secret` validation already covers minimum @@ -40,8 +40,17 @@ pub struct LegacySecret { pub valid_until: chrono::DateTime, } +impl std::fmt::Debug for LegacySecret { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LegacySecret") + .field("secret", &"***redacted***") + .field("valid_until", &self.valid_until) + .finish() + } +} + /// Authentication configuration. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[non_exhaustive] pub struct AuthConfig { /// Required for HS256. @@ -106,6 +115,45 @@ pub struct AuthConfig { /// are silently dropped at middleware construction. #[serde(default)] pub legacy_secrets: Vec, + + /// When `true` (default), browser clients (forge-svelte, forge-dioxus on + /// wasm) treat the refresh token as an `HttpOnly; Secure; SameSite=Strict` + /// cookie and do **not** persist it in JS-reachable storage. Your + /// `refresh` mutation should set the cookie on issue and clear it on + /// rotation/logout; the clients send it automatically via `credentials: + /// include`. 
+ /// + /// Set to `false` only if you cannot serve the refresh endpoint from the + /// same registrable domain as the frontend, or for legacy clients that + /// must read the refresh token from a response body. + #[serde(default = "default_true")] + pub refresh_cookie: bool, +} + +impl std::fmt::Debug for AuthConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthConfig") + .field( + "jwt_secret", + &self.jwt_secret.as_ref().map(|_| "***redacted***"), + ) + .field("jwt_algorithm", &self.jwt_algorithm) + .field("jwt_issuer", &self.jwt_issuer) + .field("jwt_audience", &self.jwt_audience) + .field("access_token_ttl", &self.access_token_ttl) + .field("refresh_token_ttl", &self.refresh_token_ttl) + .field("jwks_url", &self.jwks_url) + .field("jwks_cache_ttl", &self.jwks_cache_ttl) + .field("session_ttl", &self.session_ttl) + .field("jwt_leeway", &self.jwt_leeway) + .field("audience_required", &self.audience_required) + .field("required_claims", &self.required_claims) + .field("session_cookie_ttl", &self.session_cookie_ttl) + .field("jwks_require_kid", &self.jwks_require_kid) + .field("legacy_secrets", &self.legacy_secrets) + .field("refresh_cookie", &self.refresh_cookie) + .finish() + } } impl Default for AuthConfig { @@ -126,6 +174,7 @@ impl Default for AuthConfig { session_cookie_ttl: None, jwks_require_kid: default_true(), legacy_secrets: Vec::new(), + refresh_cookie: true, } } } @@ -190,12 +239,24 @@ impl AuthConfig { } } JwtAlgorithm::RS256 => { - if self.jwks_url.is_none() { + let Some(url) = self.jwks_url.as_deref() else { return Err(ForgeError::config( "auth.jwks_url is required for RSA algorithms (RS256). \ Set auth.jwks_url to your identity provider's JWKS endpoint, \ or switch to HS256 and provide auth.jwt_secret for symmetric signing.", )); + }; + // Plain HTTP would let an on-path attacker substitute keys and + // mint arbitrary RS256 tokens. 
Loopback is allowed for local + // dev so test mocks don't need TLS termination. + if let Some(hostname) = crate::util::http_hostname(url) + && !crate::util::is_loopback_host(hostname) + { + return Err(ForgeError::config(format!( + "auth.jwks_url '{url}' uses plain HTTP. JWKS must be fetched over \ + HTTPS (or from loopback for local development) so an on-path \ + attacker cannot substitute signing keys." + ))); } } } diff --git a/crates/forge-core/src/config/database.rs b/crates/forge-core/src/config/database.rs index f808703f..2e65af18 100644 --- a/crates/forge-core/src/config/database.rs +++ b/crates/forge-core/src/config/database.rs @@ -11,7 +11,7 @@ use super::types::DurationStr; /// separation belongs at the worker level, not the connection level. The /// single-pool contention model and sizing formula are documented at the /// runtime side in `forge_runtime::pg::pool` module docs. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] #[non_exhaustive] pub struct DatabaseConfig { @@ -56,6 +56,24 @@ pub struct DatabaseConfig { pub test_before_acquire: bool, } +impl std::fmt::Debug for DatabaseConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let redacted_replicas: Vec<&str> = + self.replica_urls.iter().map(|_| "***redacted***").collect(); + f.debug_struct("DatabaseConfig") + .field("url", &"***redacted***") + .field("pool_size", &self.pool_size) + .field("pool_timeout", &self.pool_timeout) + .field("statement_timeout", &self.statement_timeout) + .field("replica_urls", &redacted_replicas) + .field("read_from_replica", &self.read_from_replica) + .field("replica_pool_size", &self.replica_pool_size) + .field("min_pool_size", &self.min_pool_size) + .field("test_before_acquire", &self.test_before_acquire) + .finish() + } +} + impl Default for DatabaseConfig { fn default() -> Self { Self { diff --git a/crates/forge-core/src/config/loader.rs 
b/crates/forge-core/src/config/loader.rs index e67428b7..a5571679 100644 --- a/crates/forge-core/src/config/loader.rs +++ b/crates/forge-core/src/config/loader.rs @@ -48,13 +48,19 @@ pub fn substitute_env_vars(content: &str) -> String { /// Parse `VAR-default` or `VAR:-default` into (name, optional default). /// Both forms behave identically (fallback when unset). `:-` is checked /// first so its `-` doesn't get matched by the plain `-` branch. +/// +/// For the bare `-` form, the split is taken at the LAST `-` so that +/// `${MY-NAMESPACE-VAR-fallback}` parses to name `MY-NAMESPACE-VAR` +/// (which then fails `is_valid_env_var_name` and the literal is +/// preserved) rather than silently substituting `$MY` with default +/// `NAMESPACE-VAR-fallback`. #[allow(clippy::indexing_slicing)] // All indices from str::find(); guaranteed valid. fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) { if let Some(pos) = inner.find(":-") { return (&inner[..pos], Some(&inner[pos + 2..])); } - if let Some(pos) = inner.find('-') { - return (&inner[..pos], Some(&inner[pos + 1..])); + if let Some((name, default)) = inner.rsplit_once('-') { + return (name, Some(default)); } (inner, None) } @@ -121,4 +127,96 @@ mod tests { let result = substitute_env_vars(input); assert_eq!(result, r#"val = """#); } + + #[test] + fn plain_braced_var_substituted_when_set() { + // `${VAR}` with no default, variable present -> raw value. 
+ unsafe { std::env::set_var("TEST_FORGE_PLAIN_SET", "postgres://db") }; + + let input = r#"url = "${TEST_FORGE_PLAIN_SET}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"url = "postgres://db""#); + + unsafe { std::env::remove_var("TEST_FORGE_PLAIN_SET") }; + } + + #[test] + fn set_var_wins_over_dash_default() { + unsafe { std::env::set_var("TEST_FORGE_DASH_SET", "real") }; + + let input = r#"x = "${TEST_FORGE_DASH_SET-fallback}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"x = "real""#); + + unsafe { std::env::remove_var("TEST_FORGE_DASH_SET") }; + } + + #[test] + fn dash_split_takes_last_dash() { + // `parse_var_with_default` splits on the LAST `-`, so the name here is + // "TEST_FORGE_NS_VAR" and the default is "tail". Name is valid and unset, + // so the default wins. + unsafe { std::env::remove_var("TEST_FORGE_NS_VAR") }; + + let input = r#"v = "${TEST_FORGE_NS_VAR-tail}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "tail""#); + } + + #[test] + fn invalid_var_name_from_multi_dash_preserves_literal() { + // Last-dash split yields name "MY-NAMESPACE-VAR", which fails + // `is_valid_env_var_name` (contains '-'). The whole `${...}` is kept + // verbatim rather than silently substituting a partial match. + unsafe { std::env::remove_var("MY") }; + + let input = r#"v = "${MY-NAMESPACE-VAR-fallback}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${MY-NAMESPACE-VAR-fallback}""#); + } + + #[test] + fn lowercase_var_name_is_invalid_and_preserved() { + // Env var names must be uppercase/underscore-led; a lowercase name is + // not treated as a variable. + let input = r#"v = "${lowercase}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${lowercase}""#); + } + + #[test] + fn unterminated_brace_kept_verbatim() { + // No closing `}` -> the remainder is emitted as-is, no panic. 
+ let input = r#"v = "${TEST_FORGE_UNTERMINATED"#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${TEST_FORGE_UNTERMINATED"#); + } + + #[test] + fn colon_dash_split_is_preferred_over_plain_dash() { + // `:-` is checked before plain `-`, so the name is the part before `:-` + // and the `-` inside the default is left intact. + unsafe { std::env::remove_var("TEST_FORGE_CDASH") }; + + let input = r#"v = "${TEST_FORGE_CDASH:-a-b-c}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "a-b-c""#); + } + + #[test] + fn parse_var_with_default_forms() { + assert_eq!(parse_var_with_default("VAR"), ("VAR", None)); + assert_eq!( + parse_var_with_default("VAR-default"), + ("VAR", Some("default")) + ); + assert_eq!( + parse_var_with_default("VAR:-default"), + ("VAR", Some("default")) + ); + // Last-dash split. + assert_eq!(parse_var_with_default("A-B-C"), ("A-B", Some("C"))); + // Colon-dash beats plain dash and keeps trailing dashes in the default. + assert_eq!(parse_var_with_default("V:-a-b"), ("V", Some("a-b"))); + } } diff --git a/crates/forge-core/src/config/mod.rs b/crates/forge-core/src/config/mod.rs index 3b1eda99..8c6bc450 100644 --- a/crates/forge-core/src/config/mod.rs +++ b/crates/forge-core/src/config/mod.rs @@ -104,9 +104,6 @@ pub struct ForgeConfig { #[serde(default)] pub realtime: RealtimeConfig, - - #[serde(default)] - pub email: crate::email::EmailConfig, } impl ForgeConfig { @@ -218,6 +215,39 @@ impl ForgeConfig { ))); } + let ratio = self.observability.sampling_ratio; + if !ratio.is_finite() || !(0.0..=1.0).contains(&ratio) { + return Err(ForgeError::config(format!( + "observability.sampling_ratio must be a finite number in [0.0, 1.0], got {ratio}" + ))); + } + + if let Some(path) = &self.signals.geoip_db_path + && !path.is_empty() + { + let p = std::path::Path::new(path); + if !p.exists() { + return Err(ForgeError::config(format!( + "signals.geoip_db_path points to '{path}' which does not exist" + ))); + } + if 
std::fs::File::open(p).is_err() { + return Err(ForgeError::config(format!( + "signals.geoip_db_path '{path}' exists but is not readable" + ))); + } + } + + if self.gateway.cors_enabled + && self.gateway.cors_origins.iter().any(|o| o == "*") + && self.gateway.cors_origins.len() == 1 + { + tracing::warn!( + "gateway.cors_origins = [\"*\"] allows any origin; browsers reject \ + wildcard with credentialed requests. Set explicit origins for production." + ); + } + for entry in &self.gateway.trusted_proxies { if entry.parse::().is_err() && entry.parse::().is_err() { @@ -251,7 +281,6 @@ impl ForgeConfig { signals: SignalsConfig::default(), rate_limit: RateLimitSettings::default(), realtime: RealtimeConfig::default(), - email: crate::email::EmailConfig::default(), } } } @@ -821,6 +850,142 @@ mod tests { assert_eq!(entry.valid_until.to_rfc3339(), "2099-01-01T00:00:00+00:00"); } + #[test] + fn validate_rejects_invalid_trusted_proxy_entry() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + trusted_proxies = ["not-an-ip"] + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("trusted_proxies") && err.contains("not-an-ip"), + "expected trusted_proxies rejection, got: {err}" + ); + } + + #[test] + fn validate_accepts_ip_and_cidr_trusted_proxies() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + trusted_proxies = ["10.0.0.1", "10.0.0.0/8", "::1", "fd00::/8"] + "#; + assert!(ForgeConfig::parse_toml(toml).is_ok()); + } + + #[test] + fn validate_rejects_sampling_ratio_above_one() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = 1.5 + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("sampling_ratio") && err.contains("[0.0, 1.0]"), + "expected sampling_ratio bound error, got: {err}" + ); + } + + #[test] + fn validate_rejects_negative_sampling_ratio() { + let toml = r#" + 
[database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = -0.1 + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("sampling_ratio"), + "expected sampling_ratio bound error, got: {err}" + ); + } + + #[test] + fn validate_accepts_sampling_ratio_boundaries() { + for ratio in ["0.0", "1.0", "0.5"] { + let toml = format!( + r#" + [database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = {ratio} + "# + ); + assert!( + ForgeConfig::parse_toml(&toml).is_ok(), + "ratio {ratio} should validate" + ); + } + } + + #[test] + fn validate_rejects_debounce_quiet_window_exceeding_max_wait() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [realtime] + debounce_quiet_window = "500ms" + debounce_max_wait = "200ms" + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("debounce_quiet_window") && err.contains("debounce_max_wait"), + "expected debounce ordering error, got: {err}" + ); + } + + #[test] + fn validate_accepts_debounce_quiet_window_equal_to_max_wait() { + // quiet == max is allowed; only quiet > max is rejected. + let toml = r#" + [database] + url = "postgres://localhost/test" + [realtime] + debounce_quiet_window = "200ms" + debounce_max_wait = "200ms" + "#; + assert!(ForgeConfig::parse_toml(toml).is_ok()); + } + + #[test] + fn validate_rejects_cors_origin_without_scheme() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + cors_enabled = true + cors_origins = ["example.com"] + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("http://") && err.contains("https://"), + "expected scheme-required error, got: {err}" + ); + } + + #[test] + fn validate_rejects_cors_origin_with_control_char() { + // A control char in an origin would corrupt the response header. 
+ let toml = " + [database] + url = \"postgres://localhost/test\" + [gateway] + cors_enabled = true + cors_origins = [\"https://exa\tmple.com\"] + "; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("invalid origin") || err.contains("valid HTTP header"), + "expected invalid-origin error, got: {err}" + ); + } + #[test] fn realtime_quota_fields_parse_and_enforce() { let toml = r#" diff --git a/crates/forge-core/src/config/security.rs b/crates/forge-core/src/config/security.rs index a7930b22..0460402b 100644 --- a/crates/forge-core/src/config/security.rs +++ b/crates/forge-core/src/config/security.rs @@ -3,8 +3,19 @@ use serde::{Deserialize, Serialize}; /// Security configuration. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Clone, Serialize, Deserialize, Default)] #[non_exhaustive] pub struct SecurityConfig { pub secret_key: Option, } + +impl std::fmt::Debug for SecurityConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SecurityConfig") + .field( + "secret_key", + &self.secret_key.as_ref().map(|_| "***redacted***"), + ) + .finish() + } +} diff --git a/crates/forge-core/src/config/signals.rs b/crates/forge-core/src/config/signals.rs index 980158a0..8954c7b7 100644 --- a/crates/forge-core/src/config/signals.rs +++ b/crates/forge-core/src/config/signals.rs @@ -7,6 +7,10 @@ use serde::{Deserialize, Serialize}; use super::default_true; use super::types::DurationStr; +fn default_false() -> bool { + false +} + /// Signals configuration for built-in product analytics and frontend diagnostics. /// /// Captures user behavior, acquisition channels, feature usage, and frontend @@ -15,7 +19,10 @@ use super::types::DurationStr; #[non_exhaustive] pub struct SignalsConfig { /// Enable the signals pipeline (event ingestion, auto-capture, dashboards). - #[serde(default = "default_true")] + /// + /// Off by default so new projects ship without product analytics enabled. 
+ /// Set `signals.enabled = true` in forge.toml to opt in. + #[serde(default = "default_false")] pub enabled: bool, /// Auto-capture RPC calls as events without user code. @@ -68,12 +75,19 @@ pub struct SignalsConfig { /// database provides country-level resolution with zero configuration. #[serde(default)] pub geoip_db_path: Option, + + /// Per-IP request ceiling for the `/signal` endpoint, measured over a + /// rolling 60-second window. Generous enough to absorb legitimate bursts + /// (page view + web-vital flush + a handful of tracked events on a + /// navigation) while still capping runaway clients. + #[serde(default = "default_rate_limit_per_minute")] + pub rate_limit_per_minute: u32, } impl Default for SignalsConfig { fn default() -> Self { Self { - enabled: true, + enabled: false, auto_capture: true, diagnostics: true, session_timeout: default_session_timeout(), @@ -85,10 +99,15 @@ impl Default for SignalsConfig { excluded_functions: Vec::new(), bot_detection: true, geoip_db_path: None, + rate_limit_per_minute: default_rate_limit_per_minute(), } } } +fn default_rate_limit_per_minute() -> u32 { + 600 +} + fn default_session_timeout() -> DurationStr { DurationStr::new(Duration::from_secs(1800)) } @@ -117,7 +136,7 @@ mod tests { #[tokio::test] async fn default_config_has_correct_values() { let config = SignalsConfig::default(); - assert!(config.enabled); + assert!(!config.enabled); assert!(config.auto_capture); assert!(config.diagnostics); assert_eq!(config.session_timeout.as_secs(), 1800); @@ -141,7 +160,7 @@ mod tests { let from_table: Wrapper = toml::from_str("[signals]").unwrap(); for config in [from_empty, from_table.signals] { - assert!(config.enabled); + assert!(!config.enabled); assert!(config.auto_capture); assert!(config.diagnostics); assert_eq!(config.session_timeout.as_secs(), 1800); diff --git a/crates/forge-core/src/context.rs b/crates/forge-core/src/context.rs index aea72367..8f813a39 100644 --- a/crates/forge-core/src/context.rs +++ 
b/crates/forge-core/src/context.rs @@ -87,7 +87,7 @@ impl HandlerContext for crate::function::MutationContext { // MutationContext::tx() returns DbConn, not ForgeDb. // For HandlerContext we expose the pool-backed ForgeDb view, which // intentionally bypasses the active transaction. - crate::function::ForgeDb::from_pool(self.bypass_pool()) + crate::function::ForgeDb::from_pool(self.pool_outside_transaction()) } fn db_conn(&self) -> DbConn<'_> { diff --git a/crates/forge-core/src/cron/schedule.rs b/crates/forge-core/src/cron/schedule.rs index b833ef8d..c9ff8d70 100644 --- a/crates/forge-core/src/cron/schedule.rs +++ b/crates/forge-core/src/cron/schedule.rs @@ -35,6 +35,15 @@ impl CronSchedule { }) } + /// Validate a timezone string at registration time. Returns an error when + /// the timezone is not recognised so misconfigured crons fail loudly at + /// startup instead of silently never firing. + pub fn validate_timezone(timezone: &str) -> Result<(), CronParseError> { + timezone.parse::().map(|_| ()).map_err(|e| { + CronParseError::InvalidExpression(format!("invalid timezone '{timezone}': {e}")) + }) + } + /// Create a cron schedule from an expression that was already validated at compile time. /// /// Falls back to a non-firing schedule if parsing somehow fails, which cannot happen @@ -96,12 +105,17 @@ impl CronSchedule { return vec![]; }; - let local_start = start.with_timezone(&tz); + // `cron::Schedule::after` is exclusive of the boundary. Subtract one + // second so a scheduled tick that lands exactly on `start` is still + // emitted — otherwise a 1 s scheduler poll can drop a tick whose + // moment coincides with the window edge. 
+ let local_start = start.with_timezone(&tz) - chrono::Duration::seconds(1); let local_end = end.with_timezone(&tz); schedule .after(&local_start) .take_while(|dt| *dt <= local_end) + .filter(|dt| *dt >= start.with_timezone(&tz)) .map(|dt| dt.with_timezone(&Utc)) .collect() } diff --git a/crates/forge-core/src/email/mod.rs b/crates/forge-core/src/email/mod.rs deleted file mode 100644 index 3b4018ab..00000000 --- a/crates/forge-core/src/email/mod.rs +++ /dev/null @@ -1,158 +0,0 @@ -//! Email sending trait and types. -//! -//! Defines the `EmailSender` trait used by handler contexts via `ctx.email()`. -//! The runtime provides concrete implementations (SMTP, HTTP-based providers). - -use std::future::Future; -use std::pin::Pin; - -use crate::error::Result; - -/// An email message. -#[derive(Debug, Clone)] -pub struct Email { - /// Overrides the default `from` in config if set. - pub from: Option, - pub to: Vec, - pub cc: Vec, - pub bcc: Vec, - pub subject: String, - pub text: Option, - pub html: Option, - pub reply_to: Option, -} - -impl Email { - /// Create a new email to a single recipient. - pub fn to(recipient: impl Into) -> EmailBuilder { - EmailBuilder { - email: Self { - from: None, - to: vec![recipient.into()], - cc: Vec::new(), - bcc: Vec::new(), - subject: String::new(), - text: None, - html: None, - reply_to: None, - }, - } - } -} - -/// Builder for constructing email messages. 
-pub struct EmailBuilder { - email: Email, -} - -impl EmailBuilder { - pub fn to(mut self, recipient: impl Into) -> Self { - self.email.to.push(recipient.into()); - self - } - - pub fn from(mut self, sender: impl Into) -> Self { - self.email.from = Some(sender.into()); - self - } - - pub fn cc(mut self, recipient: impl Into) -> Self { - self.email.cc.push(recipient.into()); - self - } - - pub fn bcc(mut self, recipient: impl Into) -> Self { - self.email.bcc.push(recipient.into()); - self - } - - pub fn subject(mut self, subject: impl Into) -> Self { - self.email.subject = subject.into(); - self - } - - pub fn text(mut self, body: impl Into) -> Self { - self.email.text = Some(body.into()); - self - } - - pub fn html(mut self, body: impl Into) -> Self { - self.email.html = Some(body.into()); - self - } - - pub fn reply_to(mut self, address: impl Into) -> Self { - self.email.reply_to = Some(address.into()); - self - } - - pub fn build(self) -> Email { - self.email - } -} - -/// Trait for sending emails from handler contexts. -/// -/// Implemented by the runtime for SMTP and HTTP-based providers (Resend, SES). -/// Mocked in test contexts. -pub trait EmailSender: Send + Sync + 'static { - /// Send an email. Returns the provider's message ID on success. - fn send<'a>( - &'a self, - email: &'a Email, - ) -> Pin> + Send + 'a>>; -} - -/// Email configuration from forge.toml. -#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] -#[serde(default)] -pub struct EmailConfig { - pub enabled: bool, - /// Provider: "smtp", "resend", "ses", "log" (development). - pub provider: String, - /// Default sender address. - pub from: String, - pub smtp_host: Option, - /// Default 587. - pub smtp_port: Option, - /// Env var containing the API key or SMTP password. 
- pub secret_env: Option, -} - -#[cfg(test)] -#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)] -mod tests { - use super::*; - - #[test] - fn email_builder_creates_message() { - let email = Email::to("user@example.com") - .from("noreply@app.com") - .subject("Hello") - .text("Hi there") - .html("

Hi there

") - .cc("cc@example.com") - .bcc("bcc@example.com") - .reply_to("reply@app.com") - .build(); - - assert_eq!(email.to, vec!["user@example.com"]); - assert_eq!(email.from.as_deref(), Some("noreply@app.com")); - assert_eq!(email.subject, "Hello"); - assert_eq!(email.text.as_deref(), Some("Hi there")); - assert_eq!(email.html.as_deref(), Some("

Hi there

")); - assert_eq!(email.cc, vec!["cc@example.com"]); - assert_eq!(email.bcc, vec!["bcc@example.com"]); - assert_eq!(email.reply_to.as_deref(), Some("reply@app.com")); - } - - #[test] - fn email_builder_multiple_recipients() { - let email = Email::to("a@example.com") - .to("b@example.com") - .subject("Test") - .build(); - - assert_eq!(email.to.len(), 2); - } -} diff --git a/crates/forge-core/src/error.rs b/crates/forge-core/src/error.rs index e095fc4e..f755cc63 100644 --- a/crates/forge-core/src/error.rs +++ b/crates/forge-core/src/error.rs @@ -15,8 +15,17 @@ pub enum ForgeError { source: Option>, }, - #[error("Database error: {0}")] - Database(#[from] sqlx::Error), + /// Wraps the inner [`sqlx::Error`] without rendering it in `Display`. + /// The raw sqlx error (which may contain constraint names, schema names, + /// or bound parameter previews) is reachable via [`std::error::Error::source`] + /// for structured logging, but the public `Display` impl emits a generic + /// "database error" so it is safe to surface in API responses. + #[error("database error")] + Database( + #[source] + #[from] + sqlx::Error, + ), #[error("Job cancelled: {0}")] JobCancelled(String), @@ -158,10 +167,31 @@ impl ForgeError { } pub fn is_retryable(&self) -> bool { - matches!( - self, - Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. } - ) + match self { + Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. } => true, + Self::Database(err) => is_transient_sqlx_error(err), + _ => false, + } + } +} + +/// Heuristic for sqlx errors that are safe-to-retry transient failures: +/// pool checkout timeouts, dropped or closed connections, and IO errors +/// against the database socket. Logical errors (constraint violations, +/// type mismatches, missing rows) intentionally do not retry. 
+fn is_transient_sqlx_error(err: &sqlx::Error) -> bool { + match err { + sqlx::Error::PoolTimedOut | sqlx::Error::PoolClosed | sqlx::Error::WorkerCrashed => true, + sqlx::Error::Io(_) => true, + sqlx::Error::Database(db_err) => { + // PostgreSQL connection_exception family (08xxx) and + // statement_timeout (57014) are transient. + db_err + .code() + .map(|c| c.starts_with("08") || c == "57014" || c == "57P03") + .unwrap_or(false) + } + _ => false, } } @@ -180,10 +210,12 @@ impl From for ForgeError { crate::http::CircuitBreakerError::Request(err) if err.is_timeout() => { ForgeError::Timeout(err.to_string()) } - crate::http::CircuitBreakerError::Request(err) => ForgeError::Internal { - context: "HTTP request failed".to_string(), - source: Some(Box::new(err)), - }, + crate::http::CircuitBreakerError::Request(err) => { + // Non-timeout reqwest failures (connection refused, DNS, + // TLS) are upstream-side problems, not local bugs. Map to + // 503 so clients understand it's worth retrying. + ForgeError::ServiceUnavailable(format!("HTTP request failed: {err}")) + } crate::http::CircuitBreakerError::PrivateHostBlocked(host) => { ForgeError::Forbidden(format!("Outbound request to private host '{host}' blocked")) } @@ -210,7 +242,7 @@ mod tests { ), ( ForgeError::Database(sqlx::Error::RowNotFound), - "Database error: no rows returned by a query that expected to return at least one row", + "database error", ), ( ForgeError::JobCancelled("user request".into()), @@ -438,4 +470,143 @@ mod tests { assert_eq!(err.to_string(), "Internal error: connection failed"); assert!(err.source().is_some(), "source should be preserved"); } + + /// Minimal `sqlx::error::DatabaseError` carrying a fixed SQLSTATE code so we + /// can drive `is_transient_sqlx_error`'s `Database` arm without a live PG. 
+ #[derive(Debug)] + struct FakeDbError { + code: Option, + unique: bool, + } + + impl FakeDbError { + fn with_code(code: &str) -> Self { + Self { + code: Some(code.to_string()), + unique: false, + } + } + + fn unique_violation() -> Self { + // 23505 is PG's unique_violation. A logical constraint failure must + // not be treated as transient. + Self { + code: Some("23505".to_string()), + unique: true, + } + } + + fn no_code() -> Self { + Self { + code: None, + unique: false, + } + } + } + + impl std::fmt::Display for FakeDbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "fake db error ({:?})", self.code) + } + } + + impl std::error::Error for FakeDbError {} + + impl sqlx::error::DatabaseError for FakeDbError { + fn message(&self) -> &str { + "fake db error" + } + + fn code(&self) -> Option> { + self.code.as_deref().map(std::borrow::Cow::Borrowed) + } + + fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { + self + } + + fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) { + self + } + + fn into_error(self: Box) -> Box { + self + } + + fn kind(&self) -> sqlx::error::ErrorKind { + if self.unique { + sqlx::error::ErrorKind::UniqueViolation + } else { + sqlx::error::ErrorKind::Other + } + } + } + + fn db(err: FakeDbError) -> sqlx::Error { + sqlx::Error::Database(Box::new(err)) + } + + #[test] + fn transient_sqlx_pool_and_worker_errors_retry() { + assert!(is_transient_sqlx_error(&sqlx::Error::PoolTimedOut)); + assert!(is_transient_sqlx_error(&sqlx::Error::PoolClosed)); + assert!(is_transient_sqlx_error(&sqlx::Error::WorkerCrashed)); + } + + #[test] + fn transient_sqlx_io_error_retries() { + let io = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset"); + assert!(is_transient_sqlx_error(&sqlx::Error::Io(io))); + } + + #[test] + fn transient_sqlx_connection_family_08xxx_retries() { + // 08006 connection_failure, 08003 connection_does_not_exist, etc. 
+ assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08006" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08003" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08000" + )))); + } + + #[test] + fn transient_sqlx_statement_timeout_and_admin_shutdown_retry() { + // 57014 query_canceled (statement_timeout), 57P03 cannot_connect_now. + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "57014" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "57P03" + )))); + } + + #[test] + fn non_transient_sqlx_logical_errors_do_not_retry() { + // Constraint violation, row-not-found, and a missing code are all + // logical/non-retryable. + assert!(!is_transient_sqlx_error(&db( + FakeDbError::unique_violation() + ))); + assert!(!is_transient_sqlx_error(&db(FakeDbError::with_code( + "23503" + )))); + assert!(!is_transient_sqlx_error(&db(FakeDbError::no_code()))); + assert!(!is_transient_sqlx_error(&sqlx::Error::RowNotFound)); + // 57 family that isn't a retry code (e.g. 57000 operator_intervention). 
+ assert!(!is_transient_sqlx_error(&db(FakeDbError::with_code( + "57000" + )))); + } + + #[test] + fn is_retryable_database_delegates_to_transient_check() { + assert!(ForgeError::Database(db(FakeDbError::with_code("08006"))).is_retryable()); + assert!(ForgeError::Database(db(FakeDbError::with_code("57014"))).is_retryable()); + assert!(!ForgeError::Database(db(FakeDbError::unique_violation())).is_retryable()); + assert!(!ForgeError::Database(sqlx::Error::RowNotFound).is_retryable()); + } } diff --git a/crates/forge-core/src/function/context.rs b/crates/forge-core/src/function/context.rs index 9da60a5c..1f881ce0 100644 --- a/crates/forge-core/src/function/context.rs +++ b/crates/forge-core/src/function/context.rs @@ -53,6 +53,11 @@ use crate::auth::Claims; use crate::env::{EnvAccess, EnvProvider, RealEnvProvider}; use crate::http::CircuitBreakerClient; +/// Default outbound HTTP timeout applied by [`MutationContext::http`] when +/// no per-handler `timeout` is configured. Keeps a misbehaving downstream +/// from hanging an RPC indefinitely. +pub const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); + /// Token issuer for signing JWTs. /// /// Implemented by the runtime when HMAC auth is configured. @@ -420,6 +425,13 @@ impl<'c> sqlx::Executor<'c> for &'c mut ForgeConn<'_> { } /// Authentication context available to all functions. +/// +/// KNOWN ISSUE: `authenticated` and `user_id` encode overlapping state — +/// an authenticated subject without a UUID (Firebase, Clerk) is represented +/// as `authenticated = true` with `user_id = None`. Constructors are the +/// only places that set these; each one preserves the invariant +/// `authenticated == (user_id.is_some() || claims.contains_key("sub"))`. +/// Collapsing into a single sum type is tracked for a future cleanup. #[derive(Debug, Clone)] #[non_exhaustive] pub struct AuthContext { @@ -434,13 +446,15 @@ pub struct AuthContext { impl AuthContext { /// Create an unauthenticated context. 
pub fn unauthenticated() -> Self { - Self { + let ctx = Self { user_id: None, roles: Vec::new(), claims: HashMap::new(), authenticated: false, token_exp: None, - } + }; + debug_assert!(!ctx.authenticated && ctx.user_id.is_none()); + ctx } /// Create an authenticated context with a UUID user ID. @@ -449,13 +463,15 @@ impl AuthContext { roles: Vec, claims: HashMap, ) -> Self { - Self { + let ctx = Self { user_id: Some(user_id), roles, claims, authenticated: true, token_exp: None, - } + }; + debug_assert!(ctx.authenticated && ctx.user_id.is_some()); + ctx } /// Create an authenticated context without requiring a UUID user ID. @@ -467,13 +483,15 @@ impl AuthContext { roles: Vec, claims: HashMap, ) -> Self { - Self { + let ctx = Self { user_id: None, roles, claims, authenticated: true, token_exp: None, - } + }; + debug_assert!(ctx.authenticated && ctx.user_id.is_none()); + ctx } /// Attach the JWT expiry timestamp to this context. @@ -853,7 +871,9 @@ pub struct MutationContext { pub request: RequestMetadata, db_pool: sqlx::PgPool, http_client: CircuitBreakerClient, - /// `None` means unlimited. + /// `None` means "apply the default ceiling" ([`DEFAULT_HTTP_TIMEOUT`]). + /// A caller that genuinely needs an unbounded outbound request should + /// build its own [`reqwest::Client`] outside the framework. http_timeout: Option, job_dispatch: Option>, workflow_dispatch: Option>, @@ -867,7 +887,6 @@ pub struct MutationContext { /// 0 = unlimited. 
max_jobs_per_request: usize, kv: Option>, - email_sender: Option>, } impl MutationContext { @@ -888,7 +907,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -916,7 +934,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -945,7 +962,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -986,7 +1002,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, }; (ctx, tx_handle) @@ -1009,18 +1024,6 @@ impl MutationContext { .ok_or_else(|| crate::error::ForgeError::internal("KV store not available")) } - /// Attach an email sender. - pub fn set_email(&mut self, sender: Arc) { - self.email_sender = Some(sender); - } - - /// Access the email sender. - pub fn email(&self) -> crate::error::Result<&dyn crate::email::EmailSender> { - self.email_sender - .as_deref() - .ok_or_else(|| crate::error::ForgeError::internal("Email not configured")) - } - pub fn is_transactional(&self) -> bool { self.tx.is_some() } @@ -1045,14 +1048,15 @@ impl MutationContext { /// Direct pool access that **bypasses the active transaction**. /// - /// In a transactional mutation, this returns the raw [`sqlx::PgPool`] and - /// any queries run on it execute outside the transaction — so they will - /// not see uncommitted writes and will not be rolled back if the mutation - /// fails. Prefer [`MutationContext::conn`] or [`MutationContext::db`] for - /// anything that should participate in the transaction. Reach for this - /// only for operations that fundamentally cannot run inside a transaction - /// (e.g. `LISTEN`/`NOTIFY`, advisory locks, or background pool work). 
- pub fn bypass_pool(&self) -> &sqlx::PgPool { + /// WARNING: in a transactional mutation, this returns the raw + /// [`sqlx::PgPool`] and any queries run on it execute outside the + /// transaction — they will not see uncommitted writes and will not be + /// rolled back if the mutation fails. Prefer + /// [`MutationContext::conn`] or [`MutationContext::db`] for anything + /// that should participate in the transaction. Reach for this only for + /// operations that fundamentally cannot run inside a transaction (e.g. + /// `LISTEN`/`NOTIFY`, advisory locks, or background pool work). + pub fn pool_outside_transaction(&self) -> &sqlx::PgPool { &self.db_pool } @@ -1087,10 +1091,15 @@ impl MutationContext { /// declared an explicit `timeout`, that timeout is also applied to outbound /// HTTP requests unless the request overrides it. pub fn http(&self) -> crate::http::HttpClient { - self.http_client.with_timeout(self.http_timeout) + let timeout = self.http_timeout.or(Some(DEFAULT_HTTP_TIMEOUT)); + self.http_client.with_timeout(timeout) } - /// Get the raw reqwest client, bypassing circuit breaker execution. + /// Get the raw reqwest client, bypassing circuit breaker execution, + /// host blocklist, and retries. + /// + /// Gated behind the `escape-hatches` feature; prefer [`Self::http`]. + #[cfg(feature = "escape-hatches")] pub fn raw_http(&self) -> &reqwest::Client { self.http_client.inner() } @@ -1133,6 +1142,35 @@ impl MutationContext { self.max_jobs_per_request = limit; } + /// Atomically reserve a job-dispatch slot under `max_jobs_per_request`. + /// + /// Uses a `compare_exchange` loop so concurrent dispatches (e.g. via + /// `join_all`) cannot briefly exceed the limit. Returns an error when + /// the cap has been reached. 
+ fn reserve_job_slot(&self) -> crate::error::Result<()> { + if self.max_jobs_per_request == 0 { + return Ok(()); + } + let mut current = self.dispatched_job_count.load(Ordering::Acquire); + loop { + if current >= self.max_jobs_per_request { + return Err(crate::error::ForgeError::Validation(format!( + "max_jobs_per_request limit of {} exceeded", + self.max_jobs_per_request + ))); + } + match self.dispatched_job_count.compare_exchange( + current, + current + 1, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Ok(()), + Err(observed) => current = observed, + } + } + } + /// Issue a signed JWT from the given claims. /// /// Only available when HMAC auth is configured in `forge.toml`. @@ -1252,18 +1290,7 @@ impl MutationContext { job_type: &str, args: T, ) -> crate::error::Result { - if self.max_jobs_per_request > 0 { - let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed); - if count >= self.max_jobs_per_request { - // Undo the increment so repeated calls after the limit give a - // consistent count rather than growing without bound. 
- self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed); - return Err(crate::error::ForgeError::Validation(format!( - "max_jobs_per_request limit of {} exceeded", - self.max_jobs_per_request - ))); - } - } + self.reserve_job_slot()?; let args_json = serde_json::to_value(args)?; let dispatcher = self @@ -1312,16 +1339,7 @@ impl MutationContext { args: T, scheduled_at: DateTime, ) -> crate::error::Result { - if self.max_jobs_per_request > 0 { - let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed); - if count >= self.max_jobs_per_request { - self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed); - return Err(crate::error::ForgeError::Validation(format!( - "max_jobs_per_request limit of {} exceeded", - self.max_jobs_per_request - ))); - } - } + self.reserve_job_slot()?; let args_json = serde_json::to_value(args)?; let dispatcher = self diff --git a/crates/forge-core/src/job/traits.rs b/crates/forge-core/src/job/traits.rs index 52a3bf99..326efb8d 100644 --- a/crates/forge-core/src/job/traits.rs +++ b/crates/forge-core/src/job/traits.rs @@ -211,14 +211,51 @@ impl Default for RetryConfig { } impl RetryConfig { + /// Compute the retry delay for `attempt`. Adds ±25% jitter to the base + /// strategy so a fleet of jobs retrying after the same upstream outage + /// doesn't align to the same wall-clock second and re-thunder the + /// recovering dependency (#14 in issues doc). Also uses `checked_pow` + /// to avoid overflow at attempt 33+ (#21 in issues doc). 
pub fn calculate_backoff(&self, attempt: u32) -> Duration { let base = Duration::from_secs(1); - let backoff = match self.backoff { + let base_backoff = match self.backoff { BackoffStrategy::Fixed => base, - BackoffStrategy::Linear => base * attempt, - BackoffStrategy::Exponential => base * 2u32.pow(attempt.saturating_sub(1)), + BackoffStrategy::Linear => base.saturating_mul(attempt.max(1)), + BackoffStrategy::Exponential => { + let exp = attempt.saturating_sub(1); + let factor = 2u32.checked_pow(exp).unwrap_or(u32::MAX); + base.saturating_mul(factor) + } }; - backoff.min(self.max_backoff) + let capped = base_backoff.min(self.max_backoff); + Self::apply_jitter(capped) + } + + /// Apply ±25% jitter using a process-wide nanosecond clock as entropy. + /// No `rand`/`fastrand` dependency in the workspace; an Instant-derived + /// LCG is sufficient for desynchronizing retries. + fn apply_jitter(d: Duration) -> Duration { + let nanos = d.as_nanos(); + if nanos == 0 { + return d; + } + let now_ns = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|x| x.as_nanos()) + .unwrap_or(0); + // Stretch entropy with a small LCG step so adjacent calls don't return + // identical jitter. + let mixed = now_ns + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + // Map to [-0.25, +0.25] of the duration as a u128 fraction. + let bucket = (mixed % 1000) as i128 - 500; // -500..=499 + let delta_nanos = (nanos as i128) * bucket / 2000; // ±25% + let adjusted = (nanos as i128).saturating_add(delta_nanos); + let adjusted_u: u128 = adjusted.max(0) as u128; + // Cap at u64::MAX nanos (≈ 584 years) to fit Duration::from_nanos. + let capped = adjusted_u.min(u64::MAX as u128) as u64; + Duration::from_nanos(capped) } } @@ -272,20 +309,37 @@ mod tests { #[test] fn test_exponential_backoff() { + // Backoff now applies ±25% jitter; assert bounds rather than exact + // equality. Base values: 1s, 2s, 4s, 8s. 
let config = RetryConfig::default(); - assert_eq!(config.calculate_backoff(1), Duration::from_secs(1)); - assert_eq!(config.calculate_backoff(2), Duration::from_secs(2)); - assert_eq!(config.calculate_backoff(3), Duration::from_secs(4)); - assert_eq!(config.calculate_backoff(4), Duration::from_secs(8)); + for (attempt, base_ms) in [(1u32, 1000u128), (2, 2000), (3, 4000), (4, 8000)] { + let got = config.calculate_backoff(attempt).as_millis(); + let lo = base_ms * 75 / 100; + let hi = base_ms * 125 / 100; + assert!( + got >= lo && got <= hi, + "attempt {attempt}: expected {lo}..={hi}ms, got {got}ms" + ); + } } #[test] fn test_max_backoff_cap() { + // The cap applies before jitter, so the returned value is in + // [0.75*cap, 1.25*cap]. + let cap = Duration::from_secs(10); let config = RetryConfig { - max_backoff: Duration::from_secs(10), + max_backoff: cap, ..Default::default() }; - assert_eq!(config.calculate_backoff(10), Duration::from_secs(10)); + let got = config.calculate_backoff(10).as_millis(); + let cap_ms = cap.as_millis(); + let lo = cap_ms * 75 / 100; + let hi = cap_ms * 125 / 100; + assert!( + got >= lo && got <= hi, + "expected {lo}..={hi}ms (cap ±25%), got {got}ms" + ); } #[test] @@ -402,6 +456,18 @@ mod tests { assert!(cfg.retry_on.is_empty(), "empty list ⇒ retry on every error"); } + // calculate_backoff applies ±25% jitter; assert bounds, not equality. 
+ fn assert_within_jitter(actual: Duration, target: Duration) { + let target_ms = target.as_millis() as i128; + let actual_ms = actual.as_millis() as i128; + let low = target_ms * 75 / 100; + let high = target_ms * 125 / 100; + assert!( + actual_ms >= low && actual_ms <= high, + "{actual:?} outside ±25% of {target:?}" + ); + } + #[test] fn backoff_fixed_returns_base_for_any_attempt() { let cfg = RetryConfig { @@ -409,7 +475,7 @@ mod tests { ..Default::default() }; for attempt in [1u32, 2, 5, 100] { - assert_eq!(cfg.calculate_backoff(attempt), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(attempt), Duration::from_secs(1)); } } @@ -419,23 +485,24 @@ mod tests { backoff: BackoffStrategy::Linear, ..Default::default() }; - assert_eq!(cfg.calculate_backoff(1), Duration::from_secs(1)); - assert_eq!(cfg.calculate_backoff(5), Duration::from_secs(5)); - assert_eq!(cfg.calculate_backoff(50), Duration::from_secs(50)); + assert_within_jitter(cfg.calculate_backoff(1), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(5), Duration::from_secs(5)); + assert_within_jitter(cfg.calculate_backoff(50), Duration::from_secs(50)); } #[test] fn backoff_exponential_handles_attempt_zero_without_underflow() { // attempt = 0 ⇒ saturating_sub keeps exponent at 0 ⇒ 2^0 = 1 ⇒ base. let cfg = RetryConfig::default(); - assert_eq!(cfg.calculate_backoff(0), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(0), Duration::from_secs(1)); } #[test] fn backoff_exponential_caps_at_max_backoff_for_large_attempt() { - // attempt = 20 ⇒ 2^19 seconds = ~6 days, must cap to default 5 min. + // attempt = 20 saturates above max_backoff (300s); jitter pulls it + // down by up to 25%. 
let cfg = RetryConfig::default(); - assert_eq!(cfg.calculate_backoff(20), Duration::from_secs(300)); + assert_within_jitter(cfg.calculate_backoff(20), Duration::from_secs(300)); } #[test] diff --git a/crates/forge-core/src/lib.rs b/crates/forge-core/src/lib.rs index b8674061..c3d2b460 100644 --- a/crates/forge-core/src/lib.rs +++ b/crates/forge-core/src/lib.rs @@ -10,7 +10,6 @@ pub mod config; pub mod context; pub mod cron; pub mod daemon; -pub mod email; pub mod env; pub mod error; pub mod function; diff --git a/crates/forge-core/src/pagination.rs b/crates/forge-core/src/pagination.rs index 79436397..16d05a54 100644 --- a/crates/forge-core/src/pagination.rs +++ b/crates/forge-core/src/pagination.rs @@ -17,11 +17,20 @@ impl Cursor { Self(value.into()) } - pub fn as_str(&self) -> &str { + /// Internal accessor for serde / wire-format glue. Treat the returned + /// string as opaque: its encoding is an implementation detail and may + /// change between releases. + #[doc(hidden)] + pub fn as_inner_for_serde(&self) -> &str { &self.0 } } +/// Upper bound on items in a single [`Page`]. Helpers that build pages +/// from client-supplied limits should clamp to this value to prevent a +/// caller from extracting an unbounded number of rows. +pub const MAX_PAGE_SIZE: usize = 1000; + /// A page of results with cursor-based navigation. #[derive(Debug, Clone, Serialize, Deserialize)] #[non_exhaustive] @@ -31,7 +40,12 @@ pub struct Page { } impl Page { - pub fn new(items: Vec, page_info: PageInfo) -> Self { + /// Constructs a page, truncating `items` to [`MAX_PAGE_SIZE`] entries. + /// Callers that already enforce a stricter cap can pass shorter vecs. 
+ pub fn new(mut items: Vec, page_info: PageInfo) -> Self { + if items.len() > MAX_PAGE_SIZE { + items.truncate(MAX_PAGE_SIZE); + } Self { items, page_info } } } @@ -44,7 +58,7 @@ pub struct PageInfo { #[serde(skip_serializing_if = "Option::is_none")] pub end_cursor: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub total_count: Option, + pub total_count: Option, } impl PageInfo { @@ -77,7 +91,7 @@ mod tests { PageInfo { has_next_page: true, end_cursor: Some(Cursor::new("abc")), - total_count: Some(10), + total_count: Some(10u64), }, ); let json: serde_json::Value = serde_json::to_value(&page).unwrap(); diff --git a/crates/forge-core/src/realtime/subscription.rs b/crates/forge-core/src/realtime/subscription.rs index 07f9d556..5da5773d 100644 --- a/crates/forge-core/src/realtime/subscription.rs +++ b/crates/forge-core/src/realtime/subscription.rs @@ -77,6 +77,11 @@ pub struct AuthScope { pub tenant_id: Option, /// Hash of the sorted roles for this auth context. pub role_hash: u64, + /// Admin status is part of the dedup key: `check_owner_access` treats + /// admins as having implicit access to every owner's data, so an admin + /// must never share a group with a non-admin even when principal/tenant/ + /// roles coincide. 
+ pub is_admin: bool, } impl PartialEq for AuthScope { @@ -84,6 +89,7 @@ impl PartialEq for AuthScope { self.principal_id == other.principal_id && self.tenant_id == other.tenant_id && self.role_hash == other.role_hash + && self.is_admin == other.is_admin } } @@ -94,6 +100,7 @@ impl std::hash::Hash for AuthScope { self.principal_id.hash(state); self.tenant_id.hash(state); self.role_hash.hash(state); + self.is_admin.hash(state); } } @@ -114,6 +121,7 @@ impl AuthScope { .and_then(|v| v.as_str()) .map(ToString::to_string), role_hash, + is_admin: auth.is_admin(), } } } @@ -186,14 +194,37 @@ impl QueryGroup { use std::hash::Hash; match value { serde_json::Value::Object(map) => { + hasher.write_u8(b'o'); let mut keys: Vec<&String> = map.keys().collect(); keys.sort(); for key in keys { key.hash(hasher); Self::hash_json_canonical(&map[key], hasher); } + hasher.write_u8(b'e'); + } + serde_json::Value::Array(items) => { + hasher.write_u8(b'a'); + for item in items { + Self::hash_json_canonical(item, hasher); + } + hasher.write_u8(b'e'); + } + serde_json::Value::Null => { + hasher.write_u8(b'0'); + } + serde_json::Value::Bool(b) => { + hasher.write_u8(b'b'); + b.hash(hasher); + } + serde_json::Value::Number(n) => { + hasher.write_u8(b'n'); + n.to_string().hash(hasher); + } + serde_json::Value::String(s) => { + hasher.write_u8(b's'); + s.hash(hasher); } - other => other.to_string().hash(hasher), } } @@ -221,8 +252,11 @@ impl QueryGroup { /// Uses the runtime read set when populated, otherwise falls back to the /// compile-time table dependencies from macro extraction. 
pub fn should_invalidate(&self, change: &super::readset::Change) -> bool { + let in_compile_deps = self.table_deps.iter().any(|t| *t == change.table); + let in_runtime_set = self.read_set.tables.iter().any(|t| t == &change.table); + let table_matches = if self.read_set.tables.is_empty() { - self.table_deps.iter().any(|t| *t == change.table) + in_compile_deps } else { change.invalidates(&self.read_set) }; @@ -231,7 +265,14 @@ impl QueryGroup { return false; } - if !change.invalidates_columns(self.selected_cols) { + // `selected_cols` is captured at compile time for the macro-declared + // tables. Runtime-discovered tables (added by the manager when the + // read_set widens after execution) don't have a per-table column + // map, so applying the compile-time column filter to them would + // wrongly suppress real changes. Fall back to "always invalidate" + // for those, and only apply the column filter when the change came + // through a compile-time-declared table. + if in_compile_deps && !in_runtime_set && !change.invalidates_columns(self.selected_cols) { return false; } @@ -385,6 +426,7 @@ mod tests { principal_id: Some("user-1".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key1 = QueryGroup::compute_lookup_key( "get_projects", @@ -402,6 +444,7 @@ mod tests { principal_id: Some("user-2".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key3 = QueryGroup::compute_lookup_key( "get_projects", @@ -417,6 +460,7 @@ mod tests { principal_id: Some("u1".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key = QueryGroup::compute_lookup_key("get_items", &serde_json::json!({"id": "42"}), &scope); @@ -433,6 +477,7 @@ mod tests { principal_id: None, tenant_id: None, role_hash: 0, + is_admin: false, }; let key_ab = QueryGroup::compute_lookup_key("q", &serde_json::json!({"a": 1, "b": 2}), &scope); diff --git a/crates/forge-core/src/tenant/mod.rs b/crates/forge-core/src/tenant/mod.rs index a7f67a53..d46e9ee0 100644 
--- a/crates/forge-core/src/tenant/mod.rs +++ b/crates/forge-core/src/tenant/mod.rs @@ -142,4 +142,45 @@ mod tests { let ctx = TenantContext::strict(tenant_id); assert!(ctx.require_tenant().is_ok()); } + + #[test] + fn sql_filter_rejects_injection_attempts() { + let ctx = TenantContext::strict(Uuid::new_v4()); + // Anything outside [A-Za-z0-9_] must be refused so the column name can + // never carry SQL. Empty is rejected too. + for bad in [ + "", + "tenant_id; DROP TABLE users", + "tenant_id OR 1=1", + "tenant_id--", + "tenant\"_id", + "tenant id", + "tenant.id", + "tenant_id)", + "té", + ] { + assert!( + ctx.sql_filter(bad, 1).is_none(), + "column {bad:?} should be rejected" + ); + } + } + + #[test] + fn sql_filter_accepts_valid_identifiers_and_quotes_them() { + let tenant_id = Uuid::new_v4(); + let ctx = TenantContext::strict(tenant_id); + let (clause, id) = ctx + .sql_filter("org_id_2", 5) + .expect("alphanumeric+underscore column is valid"); + assert_eq!(clause, "\"org_id_2\" = $5"); + assert_eq!(id, tenant_id); + } + + #[test] + fn sql_filter_returns_none_without_tenant_even_for_valid_column() { + // No tenant id => nothing to scope by, regardless of column validity. 
+ let ctx = TenantContext::none(); + assert!(ctx.sql_filter("tenant_id", 1).is_none()); + } } diff --git a/crates/forge-core/src/testing/assertions.rs b/crates/forge-core/src/testing/assertions.rs index ed364687..d1dee025 100644 --- a/crates/forge-core/src/testing/assertions.rs +++ b/crates/forge-core/src/testing/assertions.rs @@ -274,127 +274,3 @@ where { items.iter().any(predicate) } - -#[cfg(test)] -mod tests { - use super::{assert_contains, assert_json_matches}; - use crate::error::ForgeError; - - #[test] - fn test_assert_ok_macro() { - let result: Result = Ok(42); - assert_ok!(result); - } - - #[test] - #[should_panic(expected = "expected Ok")] - fn test_assert_ok_macro_fails() { - let result: Result = Err("error".to_string()); - assert_ok!(result); - } - - #[test] - fn test_assert_err_macro() { - let result: Result = Err("error".to_string()); - assert_err!(result); - } - - #[test] - #[should_panic(expected = "expected Err")] - fn test_assert_err_macro_fails() { - let result: Result = Ok(42); - assert_err!(result); - } - - #[test] - fn test_assert_err_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_variant!(result, ForgeError::NotFound(_)); - } - - #[test] - #[should_panic(expected = "expected ForgeError::Unauthorized(_)")] - fn test_assert_err_variant_wrong_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_variant!(result, ForgeError::Unauthorized(_)); - } - - #[test] - fn test_assert_err_matches_no_guard() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(_)); - } - - #[test] - #[allow(unused_variables)] - fn test_assert_err_matches_with_guard() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(msg) if msg.contains("email")); - } - - #[test] - 
#[should_panic(expected = "guard failed")] - #[allow(unused_variables)] - fn test_assert_err_matches_guard_fails() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(msg) if msg.contains("password")); - } - - #[test] - #[should_panic(expected = "expected ForgeError::Unauthorized(_)")] - fn test_assert_err_matches_wrong_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_matches!(result, ForgeError::Unauthorized(_)); - } - - #[test] - fn test_assert_json_matches() { - let actual = serde_json::json!({ - "id": 123, - "name": "Test", - "nested": { - "foo": "bar" - } - }); - - assert!(assert_json_matches( - &actual, - &serde_json::json!({"id": 123}) - )); - assert!(assert_json_matches( - &actual, - &serde_json::json!({"name": "Test"}) - )); - assert!(assert_json_matches( - &actual, - &serde_json::json!({"nested": {"foo": "bar"}}) - )); - - assert!(!assert_json_matches( - &actual, - &serde_json::json!({"id": 456}) - )); - assert!(!assert_json_matches( - &actual, - &serde_json::json!({"missing": true}) - )); - } - - #[test] - fn test_assert_json_matches_arrays() { - let actual = serde_json::json!([1, 2, 3]); - assert!(assert_json_matches(&actual, &serde_json::json!([1, 2, 3]))); - assert!(!assert_json_matches(&actual, &serde_json::json!([1, 2]))); - assert!(!assert_json_matches(&actual, &serde_json::json!([1, 2, 4]))); - } - - #[test] - fn test_assert_contains() { - let items = vec![1, 2, 3, 4, 5]; - assert!(assert_contains(&items, |x| *x == 3)); - assert!(!assert_contains(&items, |x| *x == 6)); - } -} diff --git a/crates/forge-core/src/testing/context/mcp_tool.rs b/crates/forge-core/src/testing/context/mcp_tool.rs index 78f075fb..101fefd0 100644 --- a/crates/forge-core/src/testing/context/mcp_tool.rs +++ b/crates/forge-core/src/testing/context/mcp_tool.rs @@ -137,8 +137,17 @@ impl TestMcpToolContextBuilder { self } + 
/// Set the tenant id for multi-tenant testing. + /// + /// Production code reads the tenant from `auth.claims["tenant_id"]`, so + /// this writes the same value into the claims map. Tests calling + /// `ctx.auth.tenant_id()` then behave identically to production. pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } diff --git a/crates/forge-core/src/testing/context/query.rs b/crates/forge-core/src/testing/context/query.rs index 15024cd0..51a41637 100644 --- a/crates/forge-core/src/testing/context/query.rs +++ b/crates/forge-core/src/testing/context/query.rs @@ -128,8 +128,16 @@ impl TestQueryContextBuilder { } /// Set the tenant ID for multi-tenant testing. + /// + /// Production code reads the tenant from `auth.claims["tenant_id"]`, so + /// this writes the same value into the claims map. Tests calling + /// `ctx.auth.tenant_id()` then behave identically to production. 
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } @@ -167,45 +175,6 @@ impl TestQueryContextBuilder { mod tests { use super::*; - #[test] - fn test_minimal_context() { - let ctx = TestQueryContext::minimal(); - assert!(!ctx.auth.is_authenticated()); - assert!(ctx.db().is_none()); - } - - #[test] - fn test_authenticated_context() { - let user_id = Uuid::new_v4(); - let ctx = TestQueryContext::authenticated(user_id); - assert!(ctx.auth.is_authenticated()); - assert_eq!(ctx.user_id().unwrap(), user_id); - } - - #[test] - fn test_context_with_roles() { - let ctx = TestQueryContext::builder() - .as_user(Uuid::new_v4()) - .with_role("admin") - .with_role("user") - .build(); - - assert!(ctx.has_role("admin")); - assert!(ctx.has_role("user")); - assert!(!ctx.has_role("superuser")); - } - - #[test] - fn test_context_with_claims() { - let ctx = TestQueryContext::builder() - .as_user(Uuid::new_v4()) - .with_claim("org_id", serde_json::json!("org-123")) - .build(); - - assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123"))); - assert!(ctx.claim("nonexistent").is_none()); - } - #[test] fn test_context_with_env() { let ctx = TestQueryContext::builder() diff --git a/crates/forge-core/src/testing/context/workflow.rs b/crates/forge-core/src/testing/context/workflow.rs index 2f8b5ceb..d7fa9a2e 100644 --- a/crates/forge-core/src/testing/context/workflow.rs +++ b/crates/forge-core/src/testing/context/workflow.rs @@ -275,8 +275,15 @@ impl TestWorkflowContextBuilder { } /// Set the tenant ID. + /// + /// Production reads the tenant from `auth.claims["tenant_id"]`, so this + /// writes the same value into the claims map. 
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } diff --git a/crates/forge-core/src/testing/db.rs b/crates/forge-core/src/testing/db.rs index 57585343..3d4bfa2c 100644 --- a/crates/forge-core/src/testing/db.rs +++ b/crates/forge-core/src/testing/db.rs @@ -119,10 +119,16 @@ impl TestDatabase { /// Create a dedicated database for a single test, providing full isolation. pub async fn isolated(&self, test_name: &str) -> Result { let base_url = self.url.clone(); + // Cap the final identifier well under Postgres' 63-char limit so two + // tests with the same long prefix never collide on a truncated name. + // Layout: `forge_test_` (11) + sanitized name (<=16) + `_` (1) + + // 8 hex chars of a UUID = 36 chars total. + let uuid_hex = uuid::Uuid::new_v4().simple().to_string(); + let short_uuid: String = uuid_hex.chars().take(8).collect(); let db_name = format!( "forge_test_{}_{}", - sanitize_db_name(test_name), - uuid::Uuid::new_v4().simple() + sanitize_db_name_short(test_name), + short_uuid ); sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name)) @@ -139,7 +145,7 @@ impl TestDatabase { .map_err(ForgeError::Database)?; Ok(IsolatedTestDb { - pool: test_pool, + pool: Some(test_pool), db_name, base_url, #[cfg(feature = "testcontainers")] @@ -148,10 +154,15 @@ impl TestDatabase { } } -/// A test database scoped to a single test. Call `cleanup()` to drop it immediately, -/// or rely on future test runs to clean up orphaned databases. +/// A test database scoped to a single test. +/// +/// Cleanup happens in `Drop`: the pool is closed and `DROP DATABASE` is fired +/// on a fresh sync connection via `tokio::task::block_in_place` + +/// `Handle::current().block_on()`. Tests that want to surface cleanup errors +/// can call [`IsolatedTestDb::cleanup`] (async) explicitly instead — `Drop` +/// then becomes a no-op. 
pub struct IsolatedTestDb { - pool: PgPool, + pool: Option, db_name: String, base_url: String, #[cfg(feature = "testcontainers")] @@ -160,16 +171,24 @@ pub struct IsolatedTestDb { impl IsolatedTestDb { /// Convenience: `from_env()` → `isolated()` → `run_sql(internal_sql)` → `migrate()`. + /// + /// On a partial failure (system SQL or user migrations), the freshly-created + /// database is dropped via the standard `Drop` path of the guard struct — + /// the caller never observes a leaked database. pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result { let base = TestDatabase::from_env().await?; let db = base.isolated(test_name).await?; + // The half-built db is owned by `db`; if either step below returns + // early, `db`'s Drop fires and the database is dropped. db.run_sql(internal_sql).await?; db.migrate(migrations_dir).await?; Ok(db) } pub fn pool(&self) -> &PgPool { - &self.pool + self.pool + .as_ref() + .expect("IsolatedTestDb pool is taken only during Drop/cleanup") } pub fn db_name(&self) -> &str { @@ -179,7 +198,7 @@ impl IsolatedTestDb { /// Run raw SQL for test setup. pub async fn execute(&self, sql: &str) -> Result<()> { sqlx::query(sql) - .execute(&self.pool) + .execute(self.pool()) .await .map_err(ForgeError::Database)?; Ok(()) @@ -193,7 +212,7 @@ impl IsolatedTestDb { continue; } sqlx::query(stmt) - .execute(&self.pool) + .execute(self.pool()) .await .map_err(|e| ForgeError::internal_with("Failed to execute SQL", e))?; } @@ -201,32 +220,97 @@ impl IsolatedTestDb { } /// Drop the isolated database and close all connections. - pub async fn cleanup(self) -> Result<()> { - self.pool.close().await; + /// + /// Calling this disarms the `Drop` impl — useful for tests that want + /// cleanup errors to surface rather than being logged. 
+ pub async fn cleanup(mut self) -> Result<()> { + let pool = match self.pool.take() { + Some(p) => p, + None => return Ok(()), + }; + drop_db_async(pool, &self.base_url, &self.db_name).await + } +} - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&self.base_url) - .await - .map_err(ForgeError::Database)?; +async fn drop_db_async(pool: PgPool, base_url: &str, db_name: &str) -> Result<()> { + pool.close().await; - if let Err(e) = - sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1") - .bind(&self.db_name) - .execute(&pool) - .await - { - tracing::warn!(db = %self.db_name, error = %e, "failed to terminate backend connections during test cleanup"); - } + let admin_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(base_url) + .await + .map_err(ForgeError::Database)?; - sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name)) - .execute(&pool) + if let Err(e) = + sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1") + .bind(db_name) + .execute(&admin_pool) .await - .map_err(ForgeError::Database)?; + { + tracing::warn!(db = %db_name, error = %e, "failed to terminate backend connections during test cleanup"); + } - Ok(()) + sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", db_name)) + .execute(&admin_pool) + .await + .map_err(ForgeError::Database)?; + + Ok(()) +} + +impl Drop for IsolatedTestDb { + fn drop(&mut self) { + let Some(pool) = self.pool.take() else { + return; + }; + let base_url = self.base_url.clone(); + let db_name = self.db_name.clone(); + + // The runtime flavor decides how we drive the async cleanup: + // - multi_thread: `block_in_place` releases the worker so a nested + // `block_on` is safe. + // - current_thread: `block_in_place` panics; we instead spawn the + // cleanup as a detached task on the existing handle. 
The runtime + // drives it to completion before the process exits as long as the + // test runtime outlives this Drop (true for `#[tokio::test]` since + // the runtime owns the future). + // - no runtime: nothing we can do; log and leak. + match tokio::runtime::Handle::try_current() { + Ok(handle) => match handle.runtime_flavor() { + tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(|| { + if let Err(e) = handle.block_on(drop_db_async(pool, &base_url, &db_name)) { + tracing::warn!( + db = %db_name, + error = %e, + "IsolatedTestDb::drop failed to clean up; database leaked" + ); + } + }); + } + _ => { + handle.spawn(async move { + if let Err(e) = drop_db_async(pool, &base_url, &db_name).await { + tracing::warn!( + db = %db_name, + error = %e, + "IsolatedTestDb::drop failed to clean up; database leaked" + ); + } + }); + } + }, + Err(_) => { + tracing::warn!( + db = %db_name, + "IsolatedTestDb dropped outside a tokio runtime; database leaked" + ); + } + } } +} +impl IsolatedTestDb { /// Run migrations: loads all `.sql` files from the directory, sorts alphabetically, executes in order. pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> { if !migrations_dir.exists() { @@ -268,7 +352,7 @@ impl IsolatedTestDb { if is_blank_sql(stmt) { continue; } - sqlx::query(stmt).execute(&self.pool).await.map_err(|e| { + sqlx::query(stmt).execute(self.pool()).await.map_err(|e| { ForgeError::internal(format!("Failed to apply migration '{name}': {e}")) })?; } @@ -285,10 +369,14 @@ fn is_blank_sql(sql: &str) -> bool { .all(|l| l.trim().is_empty() || l.trim().starts_with("--")) } -fn sanitize_db_name(name: &str) -> String { +/// Sanitize a test name into something that's safe to embed in a Postgres +/// identifier. Capped at 16 characters so the final +/// `forge_test__<8hex>` identifier stays well under Postgres' 63-char +/// identifier limit (11 + 16 + 1 + 8 = 36). 
+fn sanitize_db_name_short(name: &str) -> String { name.chars() .map(|c| if c.is_alphanumeric() { c } else { '_' }) - .take(32) + .take(16) .collect() } @@ -445,11 +533,11 @@ mod tests { use super::*; #[test] - fn test_sanitize_db_name() { - assert_eq!(sanitize_db_name("my_test"), "my_test"); - assert_eq!(sanitize_db_name("my-test"), "my_test"); - assert_eq!(sanitize_db_name("my test"), "my_test"); - assert_eq!(sanitize_db_name("test::function"), "test__function"); + fn test_sanitize_db_name_short() { + assert_eq!(sanitize_db_name_short("my_test"), "my_test"); + assert_eq!(sanitize_db_name_short("my-test"), "my_test"); + assert_eq!(sanitize_db_name_short("my test"), "my_test"); + assert_eq!(sanitize_db_name_short("test::function"), "test__function"); } #[test] @@ -560,18 +648,21 @@ mod tests { } #[test] - fn sanitize_truncates_long_names() { + fn sanitize_short_caps_at_16() { let long_name = "a".repeat(100); - let sanitized = sanitize_db_name(&long_name); - assert_eq!(sanitized.len(), 32); + let sanitized = sanitize_db_name_short(&long_name); + assert_eq!(sanitized.len(), 16); + // Full identifier: 11 ("forge_test_") + 16 + 1 + 8 = 36, safely <= 63. 
+ let identifier = format!("forge_test_{}_{}", sanitized, "12345678"); + assert!(identifier.len() <= 63); } #[test] fn sanitize_handles_special_characters() { assert_eq!( - sanitize_db_name("test/with:special!chars"), - "test_with_special_chars" + sanitize_db_name_short("test/with:specia"), + "test_with_specia" ); - assert_eq!(sanitize_db_name(""), ""); + assert_eq!(sanitize_db_name_short(""), ""); } } diff --git a/crates/forge-core/src/testing/mock_dispatch.rs b/crates/forge-core/src/testing/mock_dispatch.rs index bec8e4aa..11506d5e 100644 --- a/crates/forge-core/src/testing/mock_dispatch.rs +++ b/crates/forge-core/src/testing/mock_dispatch.rs @@ -87,18 +87,21 @@ impl MockJobDispatch { cancel_reason: None, }; - self.jobs.write().expect("jobs lock poisoned").push(job); + self.jobs + .write() + .unwrap_or_else(|p| p.into_inner()) + .push(job); Ok(id) } pub fn dispatched_jobs(&self) -> Vec { - self.jobs.read().expect("jobs lock poisoned").clone() + self.jobs.read().unwrap_or_else(|p| p.into_inner()).clone() } pub fn jobs_of_type(&self, job_type: &str) -> Vec { self.jobs .read() - .expect("jobs lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .iter() .filter(|j| j.job_type == job_type) .cloned() @@ -106,7 +109,7 @@ impl MockJobDispatch { } pub fn assert_dispatched(&self, job_type: &str) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs.iter().any(|j| j.job_type == job_type); assert!( found, @@ -116,11 +119,13 @@ impl MockJobDispatch { ); } + /// Lenient: passes when *any* dispatched job with this name matches + /// the predicate. Other unrelated dispatches are ignored. 
pub fn assert_dispatched_with(&self, job_type: &str, predicate: F) where F: Fn(&serde_json::Value) -> bool, { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs .iter() .any(|j| j.job_type == job_type && predicate(&j.args)); @@ -131,8 +136,37 @@ impl MockJobDispatch { ); } + /// Strict: passes only when *every* dispatched job with this name + /// matches the predicate (and at least one such dispatch exists). + /// Use to assert a precise audience, e.g. "the email job ran for + /// user 5 and nobody else." + pub fn assert_dispatched_with_exact(&self, job_type: &str, predicate: F) + where + F: Fn(&serde_json::Value) -> bool, + { + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); + let matching: Vec<&DispatchedJob> = + jobs.iter().filter(|j| j.job_type == job_type).collect(); + assert!( + !matching.is_empty(), + "Expected at least one dispatch of '{}', but none were recorded", + job_type + ); + let mismatches: Vec<&serde_json::Value> = matching + .iter() + .filter(|j| !predicate(&j.args)) + .map(|j| &j.args) + .collect(); + assert!( + mismatches.is_empty(), + "Expected every dispatch of '{}' to match predicate; mismatched args: {:?}", + job_type, + mismatches + ); + } + pub fn assert_not_dispatched(&self, job_type: &str) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs.iter().any(|j| j.job_type == job_type); assert!( !found, @@ -142,7 +176,7 @@ impl MockJobDispatch { } pub fn assert_dispatch_count(&self, job_type: &str, expected: usize) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let count = jobs.iter().filter(|j| j.job_type == job_type).count(); assert_eq!( count, expected, @@ -152,28 +186,34 @@ impl MockJobDispatch { } pub fn clear(&self) { - self.jobs.write().expect("jobs lock 
poisoned").clear(); + self.jobs.write().unwrap_or_else(|p| p.into_inner()).clear(); } pub fn complete_job(&self, job_id: Uuid) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Completed; } } pub fn fail_job(&self, job_id: Uuid) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Failed; } } - pub fn cancel_job(&self, job_id: Uuid, reason: Option) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + /// Returns `true` when a matching dispatched job was found and marked + /// cancelled, `false` otherwise. Mirrors production semantics so tests + /// can assert cancel-of-unknown-id behaviour. + pub fn cancel_job(&self, job_id: Uuid, reason: Option) -> bool { + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Cancelled; job.cancel_reason = reason; + true + } else { + false } } } @@ -252,10 +292,7 @@ impl crate::function::JobDispatch for MockJobDispatch { job_id: Uuid, reason: Option, ) -> std::pin::Pin> + Send + '_>> { - Box::pin(async move { - self.cancel_job(job_id, reason); - Ok(true) - }) + Box::pin(async move { Ok(self.cancel_job(job_id, reason)) }) } } @@ -286,7 +323,7 @@ impl MockWorkflowDispatch { self.workflows .write() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .push(workflow); Ok(run_id) } @@ -294,14 +331,14 @@ impl MockWorkflowDispatch { pub fn started_workflows(&self) -> Vec { self.workflows .read() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .clone() } pub fn workflows_named(&self, name: &str) -> Vec { self.workflows .read() - 
.expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .iter() .filter(|w| w.workflow_name == name) .cloned() @@ -309,7 +346,7 @@ impl MockWorkflowDispatch { } pub fn assert_started(&self, workflow_name: &str) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows.iter().any(|w| w.workflow_name == workflow_name); assert!( found, @@ -326,7 +363,7 @@ impl MockWorkflowDispatch { where F: Fn(&serde_json::Value) -> bool, { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows .iter() .any(|w| w.workflow_name == workflow_name && predicate(&w.input)); @@ -338,7 +375,7 @@ impl MockWorkflowDispatch { } pub fn assert_not_started(&self, workflow_name: &str) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows.iter().any(|w| w.workflow_name == workflow_name); assert!( !found, @@ -348,7 +385,7 @@ impl MockWorkflowDispatch { } pub fn assert_start_count(&self, workflow_name: &str, expected: usize) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let count = workflows .iter() .filter(|w| w.workflow_name == workflow_name) @@ -363,19 +400,19 @@ impl MockWorkflowDispatch { pub fn clear(&self) { self.workflows .write() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .clear(); } pub fn complete_workflow(&self, run_id: Uuid) { - let mut workflows = self.workflows.write().expect("workflows lock poisoned"); + let mut workflows = self.workflows.write().unwrap_or_else(|p| p.into_inner()); if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) { workflow.status = 
WorkflowStatus::Completed; } } pub fn fail_workflow(&self, run_id: Uuid) { - let mut workflows = self.workflows.write().expect("workflows lock poisoned"); + let mut workflows = self.workflows.write().unwrap_or_else(|p| p.into_inner()); if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) { workflow.status = WorkflowStatus::Failed; } diff --git a/crates/forge-core/src/testing/mock_email.rs b/crates/forge-core/src/testing/mock_email.rs deleted file mode 100644 index d2913524..00000000 --- a/crates/forge-core/src/testing/mock_email.rs +++ /dev/null @@ -1,72 +0,0 @@ -//! Mock email sender for testing. - -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use crate::email::{Email, EmailSender}; -use crate::error::Result; - -/// Records sent emails for assertion in tests. -#[derive(Debug, Clone, Default)] -pub struct MockEmailSender { - sent: Arc>>, -} - -/// A recorded email send. -#[derive(Debug, Clone)] -pub struct SentEmail { - pub to: Vec, - pub subject: String, - pub text: Option, - pub html: Option, -} - -impl MockEmailSender { - pub fn new() -> Self { - Self::default() - } - - pub async fn sent(&self) -> Vec { - self.sent.lock().await.clone() - } - - /// Assert that exactly one email was sent to the given address. - pub async fn assert_sent_to(&self, address: &str) { - let sent = self.sent.lock().await; - let matching: Vec<_> = sent - .iter() - .filter(|e| e.to.contains(&address.to_string())) - .collect(); - assert!( - matching.len() == 1, - "Expected 1 email to {address}, found {}", - matching.len() - ); - } - - /// Assert that no emails were sent. 
- pub async fn assert_none_sent(&self) { - let sent = self.sent.lock().await; - assert!(sent.is_empty(), "Expected no emails, found {}", sent.len()); - } -} - -impl EmailSender for MockEmailSender { - fn send<'a>( - &'a self, - email: &'a Email, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - self.sent.lock().await.push(SentEmail { - to: email.to.clone(), - subject: email.subject.clone(), - text: email.text.clone(), - html: email.html.clone(), - }); - Ok(format!("mock-{}", uuid::Uuid::new_v4())) - }) - } -} diff --git a/crates/forge-core/src/testing/mock_http.rs b/crates/forge-core/src/testing/mock_http.rs index 5e46ae88..a5ebafae 100644 --- a/crates/forge-core/src/testing/mock_http.rs +++ b/crates/forge-core/src/testing/mock_http.rs @@ -327,281 +327,3 @@ impl Default for MockHttpBuilder { Self::new() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_mock_response_json() { - let response = MockResponse::json(serde_json::json!({"id": 123})); - assert_eq!(response.status, 200); - assert_eq!(response.body["id"], 123); - } - - #[test] - fn test_mock_response_error() { - let response = MockResponse::error(404, "Not found"); - assert_eq!(response.status, 404); - assert_eq!(response.body["error"], "Not found"); - } - - #[test] - fn test_pattern_matching() { - let mock = MockHttp::new(); - - assert!(mock.matches_pattern( - "https://api.example.com/users", - "https://api.example.com/users" - )); - - assert!(mock.matches_pattern( - "https://api.example.com/users/123", - "https://api.example.com/*" - )); - - assert!(mock.matches_pattern( - "https://api.example.com/v2/users", - "https://api.example.com/*/users" - )); - - assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*")); - } - - #[tokio::test] - async fn test_mock_execution() { - let mock = MockHttp::new(); - mock.add_mock_sync("https://api.example.com/*", |_| { - MockResponse::json(serde_json::json!({"status": "ok"})) - }); - - let request = MockRequest { - 
method: "GET".to_string(), - path: "/users".to_string(), - url: "https://api.example.com/users".to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - }; - - let response = mock.execute(request).await; - assert_eq!(response.status, 200); - assert_eq!(response.body["status"], "ok"); - } - - #[tokio::test] - async fn test_request_recording() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let request = MockRequest { - method: "POST".to_string(), - path: "/api/users".to_string(), - url: "https://api.example.com/users".to_string(), - headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]), - body: serde_json::json!({"name": "Test"}), - }; - - let _ = mock.execute(request).await; - - let requests = mock.requests(); - assert_eq!(requests.len(), 1); - assert_eq!(requests[0].method, "POST"); - assert_eq!(requests[0].body["name"], "Test"); - } - - #[tokio::test] - async fn test_assert_called() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let request = MockRequest { - method: "GET".to_string(), - path: "/test".to_string(), - url: "https://api.example.com/test".to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - }; - - let _ = mock.execute(request).await; - - mock.assert_called("https://api.example.com/*"); - mock.assert_called_times("https://api.example.com/*", 1); - mock.assert_not_called("https://other.com/*"); - } - - #[test] - fn test_builder() { - let mock = MockHttpBuilder::new() - .mock("https://api.example.com/*", |_| MockResponse::ok()) - .mock_json("https://other.com/*", serde_json::json!({"id": 1})) - .build(); - - assert_eq!(mock.mocks.read().unwrap().len(), 2); - } - - fn req(method: &str, url: &str, path: &str) -> MockRequest { - MockRequest { - method: method.to_string(), - path: path.to_string(), - url: url.to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - } - } - - #[test] - fn 
response_status_helpers_use_documented_codes() { - assert_eq!(MockResponse::internal_error("boom").status, 500); - assert_eq!(MockResponse::not_found("nope").status, 404); - assert_eq!(MockResponse::unauthorized("nope").status, 401); - assert_eq!(MockResponse::ok().status, 200); - - // ok() returns an empty JSON object — handlers that key off body shape - // (e.g., serde to () or empty struct) rely on this. - assert_eq!(MockResponse::ok().body, serde_json::json!({})); - } - - #[test] - fn response_json_sets_content_type_header() { - let r = MockResponse::json(serde_json::json!({"ok": true})); - assert_eq!( - r.headers.get("content-type"), - Some(&"application/json".to_string()) - ); - } - - #[test] - fn pattern_matcher_handles_leading_and_double_wildcards() { - let m = MockHttp::new(); - // Leading wildcard (pattern_parts[0] is empty). - assert!(m.matches_pattern("https://api.example.com/v1/users", "*/users")); - assert!(!m.matches_pattern("https://api.example.com/v1/posts", "*/users")); - - // Bare `*` matches anything (both pattern parts are empty strings). - assert!(m.matches_pattern("anything", "*")); - assert!(m.matches_pattern("", "*")); - } - - #[test] - fn pattern_matcher_rejects_exact_pattern_with_extra_suffix() { - let m = MockHttp::new(); - assert!(!m.matches_pattern( - "https://api.example.com/users/extra", - "https://api.example.com/users" - )); - } - - #[tokio::test] - async fn execute_falls_back_to_500_when_no_mock_matches() { - let mock = MockHttp::new(); - let r = mock.execute(req("GET", "https://nowhere/", "/")).await; - assert_eq!(r.status, 500); - assert!( - r.body["error"] - .as_str() - .unwrap_or_default() - .contains("No mock found"), - "fallback should explain the failure, got {:?}", - r.body - ); - } - - #[tokio::test] - async fn execute_records_request_even_when_no_mock_matches() { - // The recording happens before the lookup so failed-match calls still - // show up in requests() — important for diagnosing "why didn't my mock fire". 
- let mock = MockHttp::new(); - let _ = mock.execute(req("DELETE", "https://nowhere/x", "/x")).await; - let recorded = mock.requests(); - assert_eq!(recorded.len(), 1); - assert_eq!(recorded[0].method, "DELETE"); - assert_eq!(recorded[0].url, "https://nowhere/x"); - } - - #[tokio::test] - async fn execute_matches_against_path_when_url_misses() { - // Pattern only matches the path, not the full URL. - let mock = MockHttp::new(); - mock.add_mock_sync("/health", |_| MockResponse::ok()); - let r = mock - .execute(req("GET", "https://internal.svc:8080/health", "/health")) - .await; - assert_eq!(r.status, 200); - } - - #[tokio::test] - async fn execute_uses_first_registered_mock_on_overlapping_patterns() { - let mock = MockHttp::new(); - mock.add_mock_sync("https://api.example.com/*", |_| { - MockResponse::json(serde_json::json!({"hit": "first"})) - }); - mock.add_mock_sync("https://api.example.com/users", |_| { - MockResponse::json(serde_json::json!({"hit": "second"})) - }); - - let r = mock - .execute(req("GET", "https://api.example.com/users", "/users")) - .await; - assert_eq!(r.body["hit"], "first"); - } - - #[tokio::test] - async fn requests_to_filters_by_pattern() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let _ = mock - .execute(req("GET", "https://api.example.com/a", "/a")) - .await; - let _ = mock.execute(req("GET", "https://other.com/b", "/b")).await; - let _ = mock - .execute(req("GET", "https://api.example.com/c", "/c")) - .await; - - let api_calls = mock.requests_to("https://api.example.com/*"); - assert_eq!(api_calls.len(), 2); - assert!(api_calls.iter().all(|r| r.url.contains("api.example.com"))); - } - - #[tokio::test] - async fn clear_requests_and_clear_mocks_independently_reset_state() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - let _ = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(mock.requests().len(), 1); - - mock.clear_requests(); - 
assert!(mock.requests().is_empty()); - // Mocks survive a requests-only clear; the next call should still match. - let r = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(r.status, 200); - - mock.clear_mocks(); - let r = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(r.status, 500, "after clear_mocks, fallback should hit"); - } - - #[tokio::test] - async fn assert_called_with_body_runs_predicate_against_recorded_body() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let mut request = req("POST", "https://api/upload", "/upload"); - request.body = serde_json::json!({"size": 42}); - let _ = mock.execute(request).await; - - // Predicate matches — should not panic. - mock.assert_called_with_body("https://api/*", |body| body["size"] == 42); - } - - #[test] - fn defaults_match_new() { - // Default impls are wrappers around new(); just exercise them so the - // Default path doesn't silently rot. - let m1 = MockHttp::default(); - assert!(m1.requests().is_empty()); - let b1 = MockHttpBuilder::default(); - let m2 = b1.build(); - assert!(m2.requests().is_empty()); - } -} diff --git a/crates/forge-core/src/testing/mod.rs b/crates/forge-core/src/testing/mod.rs index 29b85e99..66553074 100644 --- a/crates/forge-core/src/testing/mod.rs +++ b/crates/forge-core/src/testing/mod.rs @@ -2,14 +2,12 @@ pub mod assertions; pub mod context; pub mod db; pub mod mock_dispatch; -pub mod mock_email; pub mod mock_http; pub use assertions::*; pub use context::*; pub use db::{IsolatedTestDb, TestDatabase}; pub use mock_dispatch::{DispatchedJob, MockJobDispatch, MockWorkflowDispatch, StartedWorkflow}; -pub use mock_email::{MockEmailSender, SentEmail}; pub use mock_http::{MockHttp, MockHttpBuilder, MockRequest, MockResponse}; use std::time::Duration; diff --git a/crates/forge-core/src/util/mod.rs b/crates/forge-core/src/util/mod.rs index 3c99d202..3c418690 100644 --- a/crates/forge-core/src/util/mod.rs +++ 
b/crates/forge-core/src/util/mod.rs @@ -133,6 +133,62 @@ pub fn to_camel_case(s: &str) -> String { result } +/// Normalize an args/input envelope before deserialization. +/// +/// Job and workflow handlers accept either a bare value or a single-key +/// `{"args": …}` / `{"input": …}` wrapper depending on how the caller phrased +/// the dispatch. This helper unwraps the envelope so the handler's `Args` / +/// `Input` deserialize path doesn't have to special-case both shapes. `null` +/// is collapsed to an empty object so handlers with `()` args still match. +pub fn normalize_handler_args(args: serde_json::Value) -> serde_json::Value { + let unwrapped = match &args { + serde_json::Value::Object(map) if map.len() == 1 => { + if map.contains_key("args") { + map.get("args").cloned().unwrap_or(serde_json::Value::Null) + } else if map.contains_key("input") { + map.get("input").cloned().unwrap_or(serde_json::Value::Null) + } else { + args + } + } + _ => args, + }; + + match &unwrapped { + serde_json::Value::Null => serde_json::Value::Object(serde_json::Map::new()), + _ => unwrapped, + } +} + +/// Extract the bare hostname from an authority component (`host[:port]`), +/// stripping an IPv6 bracket pair and any port. e.g. `[::1]:8080` -> `::1`, +/// `localhost:9081` -> `localhost`, `127.0.0.1` -> `127.0.0.1`. +pub fn hostname_from_authority(authority: &str) -> &str { + match authority.strip_prefix('[') { + // IPv6 literal: the hostname is everything up to the closing bracket. + Some(rest) => rest.split(']').next().unwrap_or(rest), + None => authority.split(':').next().unwrap_or(authority), + } +} + +/// True if `hostname` is a loopback address. Expects a bare hostname with no +/// port or brackets (see [`hostname_from_authority`]). +pub fn is_loopback_host(hostname: &str) -> bool { + matches!(hostname, "localhost" | "127.0.0.1" | "::1") +} + +/// Bare hostname of a plain-`http://` URL (port and IPv6 brackets stripped), +/// or `None` if `url` is not `http://`. 
+/// +/// Used to decide whether a plain-HTTP endpoint is a safe loopback exception to +/// the HTTPS requirement. A naive `starts_with("http://localhost")` check would +/// wrongly accept `http://localhost.evil.com`, so callers parse the host first. +pub fn http_hostname(url: &str) -> Option<&str> { + let rest = url.strip_prefix("http://")?; + let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest); + Some(hostname_from_authority(authority)) +} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { @@ -269,4 +325,72 @@ mod tests { assert_eq!(to_camel_case("list_all_projects"), "listAllProjects"); assert_eq!(to_camel_case("simple"), "simple"); } + + #[test] + fn normalize_handler_args_converts_null_to_empty_object() { + use serde_json::json; + assert_eq!(normalize_handler_args(json!(null)), json!({})); + } + + #[test] + fn normalize_handler_args_unwraps_args_envelope() { + use serde_json::json; + assert_eq!( + normalize_handler_args(json!({"args": {"x": 1}})), + json!({"x": 1}) + ); + assert_eq!(normalize_handler_args(json!({"args": null})), json!({})); + } + + #[test] + fn normalize_handler_args_unwraps_input_envelope() { + use serde_json::json; + assert_eq!( + normalize_handler_args(json!({"input": [1, 2]})), + json!([1, 2]) + ); + } + + #[test] + fn normalize_handler_args_preserves_other_shapes() { + use serde_json::json; + assert_eq!(normalize_handler_args(json!({"id": 7})), json!({"id": 7})); + assert_eq!(normalize_handler_args(json!([1, 2])), json!([1, 2])); + assert_eq!(normalize_handler_args(json!(42)), json!(42)); + } + + #[test] + fn hostname_from_authority_strips_port_and_brackets() { + assert_eq!(hostname_from_authority("localhost:9081"), "localhost"); + assert_eq!(hostname_from_authority("127.0.0.1"), "127.0.0.1"); + assert_eq!(hostname_from_authority("[::1]:8080"), "::1"); + assert_eq!(hostname_from_authority("[::1]"), "::1"); + assert_eq!(hostname_from_authority("example.com:443"), "example.com"); + } + + 
#[test] + fn is_loopback_host_matches_only_loopback() { + assert!(is_loopback_host("localhost")); + assert!(is_loopback_host("127.0.0.1")); + assert!(is_loopback_host("::1")); + assert!(!is_loopback_host("localhost.evil.com")); + assert!(!is_loopback_host("example.com")); + } + + #[test] + fn http_hostname_parses_host_and_rejects_spoofs() { + assert_eq!( + http_hostname("http://localhost:9081/jwks"), + Some("localhost") + ); + assert_eq!(http_hostname("http://[::1]:8080/jwks"), Some("::1")); + assert_eq!(http_hostname("http://127.0.0.1/cb?x=1"), Some("127.0.0.1")); + // The classic spoof: a subdomain of localhost is not loopback. + assert_eq!( + http_hostname("http://localhost.evil.com/cb"), + Some("localhost.evil.com") + ); + // Not plain HTTP. + assert_eq!(http_hostname("https://localhost/jwks"), None); + } } diff --git a/crates/forge-core/src/workflow/mod.rs b/crates/forge-core/src/workflow/mod.rs index a689b425..02091188 100644 --- a/crates/forge-core/src/workflow/mod.rs +++ b/crates/forge-core/src/workflow/mod.rs @@ -6,6 +6,6 @@ mod traits; pub use context::{CompensationHandler, StepState, WorkflowContext}; pub use events::{NoOpEventSender, WorkflowEventSender, serialize_payload}; -pub use step::{Step, StepBuilder, StepConfig, StepResult, StepStatus}; +pub use step::StepStatus; pub use suspend::{SuspendReason, WorkflowEvent}; pub use traits::{ForgeWorkflow, WorkflowDefStatus, WorkflowInfo, WorkflowStatus}; diff --git a/crates/forge-core/src/workflow/step.rs b/crates/forge-core/src/workflow/step.rs index 56518200..2d78cb44 100644 --- a/crates/forge-core/src/workflow/step.rs +++ b/crates/forge-core/src/workflow/step.rs @@ -1,15 +1,4 @@ -use std::future::Future; -use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; - -use serde::{Serialize, de::DeserializeOwned}; - -use crate::Result; - -/// Type alias for compensation function to reduce complexity. 
-type CompensateFn<'a, T, C> = Arc Pin> + Send + Sync + 'a>; /// Step execution status. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -25,6 +14,9 @@ pub enum StepStatus { Failed, /// Step compensation ran. Compensated, + /// Step compensation handler ran but failed; manual remediation may be + /// required for any side effects of the original step. + CompensationFailed, /// Step was skipped. Skipped, /// Step is waiting (suspended). @@ -40,6 +32,7 @@ impl StepStatus { Self::Completed => "completed", Self::Failed => "failed", Self::Compensated => "compensated", + Self::CompensationFailed => "compensation_failed", Self::Skipped => "skipped", Self::Waiting => "waiting", } @@ -67,6 +60,7 @@ impl FromStr for StepStatus { "completed" => Ok(Self::Completed), "failed" => Ok(Self::Failed), "compensated" => Ok(Self::Compensated), + "compensation_failed" => Ok(Self::CompensationFailed), "skipped" => Ok(Self::Skipped), "waiting" => Ok(Self::Waiting), _ => Err(ParseStepStatusError(s.to_string())), @@ -74,220 +68,22 @@ impl FromStr for StepStatus { } } -/// Result of a step execution. -#[derive(Debug, Clone)] -pub struct StepResult { - /// Step name. - pub name: String, - /// Step status. - pub status: StepStatus, - /// Step result (if completed). - pub value: Option, - /// Error message (if failed). - pub error: Option, -} - -/// A workflow step definition. -pub struct Step { - /// Step name. - pub name: String, - /// Step result type. - _marker: std::marker::PhantomData, -} - -impl Step { - /// Create a new step. - pub fn new(name: impl Into) -> Self { - Self { - name: name.into(), - _marker: std::marker::PhantomData, - } - } -} - -/// Builder for configuring and executing a step. 
-pub struct StepBuilder<'a, T, F, C> -where - T: Serialize + DeserializeOwned + Send + 'static, - F: Future> + Send + 'a, - C: Future> + Send + 'a, -{ - name: String, - run_fn: Option F + Send + 'a>>>, - compensate_fn: Option>, - timeout: Option, - retry_count: u32, - retry_delay: Duration, - optional: bool, - _marker: std::marker::PhantomData<(T, F, C)>, -} - -impl<'a, T, F, C> StepBuilder<'a, T, F, C> -where - T: Serialize + DeserializeOwned + Send + Clone + 'static, - F: Future> + Send + 'a, - C: Future> + Send + 'a, -{ - /// Create a new step builder. - pub fn new(name: impl Into) -> Self { - Self { - name: name.into(), - run_fn: None, - compensate_fn: None, - timeout: None, - retry_count: 0, - retry_delay: Duration::from_secs(1), - optional: false, - _marker: std::marker::PhantomData, - } - } - - /// Set the step execution function. - pub fn run(mut self, f: RF) -> Self - where - RF: FnOnce() -> F + Send + 'a, - { - self.run_fn = Some(Box::pin(f)); - self - } - - /// Set the compensation function. - /// - /// # Warning - /// - /// Compensation handlers are in-memory closures. They do **not** survive - /// process restarts. If the workflow suspends (via `ctx.sleep()` or - /// `ctx.wait_for_event()`) and the process restarts before the workflow - /// completes, registered compensation handlers are lost. The executor - /// detects this and fails the workflow with a message requiring manual - /// remediation. - pub fn compensate(mut self, f: CF) -> Self - where - CF: Fn(T) -> Pin> + Send + Sync + 'a, - { - self.compensate_fn = Some(Arc::new(f)); - self - } - - /// Set step timeout. - pub fn timeout(mut self, duration: Duration) -> Self { - self.timeout = Some(duration); - self - } - - /// Configure retry behavior. - pub fn retry(mut self, count: u32, delay: Duration) -> Self { - self.retry_count = count; - self.retry_delay = delay; - self - } - - /// Mark the step as optional (failure won't trigger compensation). 
- pub fn optional(mut self) -> Self { - self.optional = true; - self - } - - /// Get step name. - pub fn name(&self) -> &str { - &self.name - } - - /// Check if step is optional. - pub fn is_optional(&self) -> bool { - self.optional - } - - /// Get retry count. - pub fn retry_count(&self) -> u32 { - self.retry_count - } - - /// Get retry delay. - pub fn retry_delay(&self) -> Duration { - self.retry_delay - } - - /// Get timeout. - pub fn get_timeout(&self) -> Option { - self.timeout - } -} - -/// Configuration for a step (without closures, for storage). -#[derive(Debug, Clone)] -pub struct StepConfig { - /// Step name. - pub name: String, - /// Step timeout. - pub timeout: Option, - /// Retry count. - pub retry_count: u32, - /// Retry delay. - pub retry_delay: Duration, - /// Whether the step is optional. - pub optional: bool, - /// Whether the step has a compensation function. - pub has_compensation: bool, -} - -impl Default for StepConfig { - fn default() -> Self { - Self { - name: String::new(), - timeout: None, - retry_count: 0, - retry_delay: Duration::from_secs(1), - optional: false, - has_compensation: false, - } - } -} - #[cfg(test)] -#[allow(clippy::unwrap_used, clippy::indexing_slicing)] +#[allow(clippy::unwrap_used)] mod tests { use super::*; - #[test] - fn test_step_status_conversion() { - assert_eq!(StepStatus::Pending.as_str(), "pending"); - assert_eq!(StepStatus::Running.as_str(), "running"); - assert_eq!(StepStatus::Completed.as_str(), "completed"); - assert_eq!(StepStatus::Failed.as_str(), "failed"); - assert_eq!(StepStatus::Compensated.as_str(), "compensated"); - - assert_eq!("pending".parse::(), Ok(StepStatus::Pending)); - assert_eq!("completed".parse::(), Ok(StepStatus::Completed)); - } - - #[test] - fn test_step_config_default() { - let config = StepConfig::default(); - assert!(config.name.is_empty()); - assert!(!config.optional); - assert_eq!(config.retry_count, 0); - } - - #[test] - fn step_status_as_str_covers_all_variants() { - 
assert_eq!(StepStatus::Pending.as_str(), "pending"); - assert_eq!(StepStatus::Running.as_str(), "running"); - assert_eq!(StepStatus::Completed.as_str(), "completed"); - assert_eq!(StepStatus::Failed.as_str(), "failed"); - assert_eq!(StepStatus::Compensated.as_str(), "compensated"); - assert_eq!(StepStatus::Skipped.as_str(), "skipped"); - assert_eq!(StepStatus::Waiting.as_str(), "waiting"); - } - #[test] fn step_status_parse_roundtrips_every_variant() { + // StepStatus is persisted to and read back from the DB (executor.rs, + // state.rs), so as_str() and FromStr must stay inverses for every variant. for status in [ StepStatus::Pending, StepStatus::Running, StepStatus::Completed, StepStatus::Failed, StepStatus::Compensated, + StepStatus::CompensationFailed, StepStatus::Skipped, StepStatus::Waiting, ] { @@ -304,46 +100,4 @@ mod tests { // Display must echo the bad value so logs pinpoint the typo. assert!(err.to_string().contains("garbage")); } - - #[test] - fn step_constructor_records_name() { - let s: Step = Step::new("send_email"); - assert_eq!(s.name, "send_email"); - } - - type NoFut = Pin> + Send + 'static>>; - type NoComp = Pin> + Send + 'static>>; - - fn fresh_builder<'a>() -> StepBuilder<'a, u32, NoFut, NoComp> { - StepBuilder::new("noop") - } - - #[test] - fn step_builder_defaults() { - let b = fresh_builder(); - assert_eq!(b.name(), "noop"); - assert!(!b.is_optional()); - assert_eq!(b.retry_count(), 0); - assert_eq!(b.retry_delay(), Duration::from_secs(1)); - assert!(b.get_timeout().is_none()); - } - - #[test] - fn step_builder_optional_flag_flips() { - let b = fresh_builder().optional(); - assert!(b.is_optional()); - } - - #[test] - fn step_builder_retry_sets_count_and_delay() { - let b = fresh_builder().retry(3, Duration::from_millis(250)); - assert_eq!(b.retry_count(), 3); - assert_eq!(b.retry_delay(), Duration::from_millis(250)); - } - - #[test] - fn step_builder_timeout_setter() { - let b = fresh_builder().timeout(Duration::from_secs(5)); - 
assert_eq!(b.get_timeout(), Some(Duration::from_secs(5))); - } } diff --git a/crates/forge-harness/src/app.rs b/crates/forge-harness/src/app.rs index 0b471e4a..0bf91cd1 100644 --- a/crates/forge-harness/src/app.rs +++ b/crates/forge-harness/src/app.rs @@ -26,8 +26,16 @@ use crate::error::HarnessError; use crate::sse::HarnessSession; use crate::{Result, sse}; -/// Minimum bytes for an HMAC JWT secret (matches the framework's startup validator). -const TEST_JWT_SECRET: &str = "forge-harness-test-jwt-secret-please-rotate-32b"; +/// Generate a fresh 64-hex-char (256-bit) JWT secret for this harness instance. +/// A new secret per `HarnessAppBuilder::new` avoids token reuse between +/// independently-running tests and removes any constant a downstream consumer +/// could lean on. +fn random_jwt_secret() -> String { + let mut s = String::with_capacity(64); + s.push_str(&Uuid::new_v4().simple().to_string()); + s.push_str(&Uuid::new_v4().simple().to_string()); + s +} /// Builder for the in-process harness app. Use this to override defaults /// before starting; the simple path is `HarnessApp::start(test_name)`. @@ -37,6 +45,8 @@ pub struct HarnessAppBuilder { jwt_secret: String, extra_internal_sql: Vec, cors_enabled: bool, + rate_limiter: Option>, + auth_config: Option, } impl HarnessAppBuilder { @@ -46,9 +56,11 @@ impl HarnessAppBuilder { Self { test_name: test_name.into(), migrations_dir: None, - jwt_secret: TEST_JWT_SECRET.to_string(), + jwt_secret: random_jwt_secret(), extra_internal_sql: Vec::new(), cors_enabled: false, + rate_limiter: None, + auth_config: None, } } @@ -79,6 +91,32 @@ impl HarnessAppBuilder { self } + /// Install a rate limiter backend. By default the harness wires a + /// [`forge_runtime::StrictRateLimiter`] backed by the test database so the + /// gateway's RPC/signal/login throttle paths are exercised in tests. Pass + /// a custom backend (or pass [`Self::no_rate_limiter`]) to override. 
+ pub fn with_rate_limiter( + mut self, + rate_limiter: Arc, + ) -> Self { + self.rate_limiter = Some(rate_limiter); + self + } + + /// Replace the gateway's `AuthConfig`. Tests that need to exercise the + /// RS256/JWKS path, legacy secret rotation, or `required_claims` overrides + /// can supply a fully-formed config here. The default builds an HS256 + /// config with the harness's auto-generated secret. + /// + /// When this is set, [`HarnessApp::issue_token`] will only work if the + /// supplied config still has an HMAC secret on it (because `HmacTokenIssuer` + /// needs one). For pure-JWKS test paths, mint tokens via your own helper + /// and call them on the harness via `client.with_token(...)`. + pub fn with_auth_config(mut self, auth_config: AuthConfig) -> Self { + self.auth_config = Some(auth_config); + self + } + /// Boot the harness app: provision the DB, run migrations, register /// every `#[forge::*]` handler available via inventory, wire up the /// gateway, worker, reactor, and bind on `127.0.0.1:0`. @@ -118,17 +156,25 @@ impl HarnessApp { .migrations_dir .unwrap_or_else(|| PathBuf::from(".harness-no-user-migrations")); - let db = forge_core::testing::IsolatedTestDb::setup( - &builder.test_name, - &internal_sql, - &migrations_dir, - ) - .await - .map_err(HarnessError::Forge)?; - + // Order matches the documented contract: system schema -> extra_sql + // (test fixture rows) -> user migrations. Doing it manually rather + // than via `IsolatedTestDb::setup` keeps that ordering explicit. 
+ let base = forge_core::testing::TestDatabase::from_env() + .await + .map_err(HarnessError::Forge)?; + let db = base + .isolated(&builder.test_name) + .await + .map_err(HarnessError::Forge)?; + db.run_sql(&internal_sql) + .await + .map_err(HarnessError::Forge)?; for sql in &builder.extra_internal_sql { db.run_sql(sql).await.map_err(HarnessError::Forge)?; } + db.migrate(&migrations_dir) + .await + .map_err(HarnessError::Forge)?; let pool = db.pool().clone(); let database = Database::from_pool(pool.clone()); @@ -143,7 +189,8 @@ impl HarnessApp { webhooks: WebhookRegistry::new(), mcp_tools: McpToolRegistry::new(), }; - forge::auto_register_all(&mut registries); + forge::auto_register_all(&mut registries) + .map_err(|e| HarnessError::Setup(format!("auto_register_all failed: {e}")))?; // Workflow runs refuse to start unless the (name, version, signature) row // exists. Same upsert logic the production runtime runs at boot. @@ -191,11 +238,20 @@ impl HarnessApp { ); let workflow_dispatcher: Arc = workflow_executor.clone(); - let auth_config = AuthConfig::with_secret(builder.jwt_secret.clone()); + let auth_config = builder + .auth_config + .clone() + .unwrap_or_else(|| AuthConfig::with_secret(builder.jwt_secret.clone())); + // Issuing tokens via `issue_token` requires an HMAC secret. If the + // operator overrode `AuthConfig` to point at JWKS only, mint your own + // tokens and attach via `client.with_token(...)`; the issuer below is + // optional and lazily evaluated. 
let token_issuer: Arc = Arc::new(HmacTokenIssuer::from_config(&auth_config).ok_or_else(|| { HarnessError::setup( - "HmacTokenIssuer::from_config returned None; JWT secret missing or empty", + "HmacTokenIssuer::from_config returned None; \ + either set jwt_secret on the AuthConfig or skip issue_token() and \ + attach pre-minted tokens via client.with_token()", ) })?); @@ -203,11 +259,18 @@ impl HarnessApp { port: 0, auth: auth_config.clone(), cors_enabled: builder.cors_enabled, - security_headers: false, + security_headers: true, request_timeout_secs: 30, ..GatewayConfig::default() }; + // Default rate limiter: the strict PG-backed token bucket. Production + // parity is the point — tests that regress the rate-limit path now + // fail rather than silently pass. + let rate_limiter: Arc = builder + .rate_limiter + .unwrap_or_else(|| Arc::new(forge_runtime::StrictRateLimiter::new(pool.clone()))); + let gateway = GatewayServer::new( gateway_config.clone(), registries.functions.clone(), @@ -215,7 +278,8 @@ impl HarnessApp { notify_bus.clone(), ) .with_job_dispatcher(job_dispatcher.clone()) - .with_workflow_dispatcher(workflow_dispatcher.clone()); + .with_workflow_dispatcher(workflow_dispatcher.clone()) + .with_rate_limiter(rate_limiter); let reactor = gateway.reactor(); reactor @@ -412,6 +476,28 @@ impl HarnessApp { self.token_issuer.sign(&claims).map_err(HarnessError::Forge) } + /// Issue a JWT after letting the caller mutate the `Claims::builder`. Use + /// this for tests that need expired tokens (`duration_secs(-1)`), custom + /// claims (`tenant_id`, custom roles), or wrong-secret tokens (via + /// `with_auth_config` paired with a token signed by the test). 
+ /// + /// Example: + /// ```ignore + /// let expired = app.issue_token_with_claims(|b| { + /// b.user_id(user_id).duration_secs(-3600) + /// })?; + /// ``` + pub fn issue_token_with_claims(&self, build: F) -> Result + where + F: FnOnce(forge_core::ClaimsBuilder) -> forge_core::ClaimsBuilder, + { + let builder = build(forge_core::Claims::builder()); + let claims = builder + .build() + .map_err(|e| HarnessError::setup(format!("build claims: {e}")))?; + self.token_issuer.sign(&claims).map_err(HarnessError::Forge) + } + /// Open a long-lived SSE session for the given token (or anonymous). The /// returned session lets you subscribe to functions and read updates as /// the reactor pushes them. @@ -442,7 +528,16 @@ impl HarnessApp { impl Drop for HarnessApp { fn drop(&mut self) { + // Notify cooperative shutdown first so well-behaved tasks (gateway, + // worker, scheduler, notify bus) exit on their own. self.signal_shutdown(); + // Then abort any handle that's still alive so they stop touching the + // pool before `IsolatedTestDb`'s own Drop fires `DROP DATABASE`. If + // we left them running, a worker mid-poll would race the drop and + // either panic on a closed pool or block the database termination. + for handle in self.handles.drain(..) { + handle.abort(); + } } } diff --git a/crates/forge-harness/src/client.rs b/crates/forge-harness/src/client.rs index 8f71eb9a..f120a02f 100644 --- a/crates/forge-harness/src/client.rs +++ b/crates/forge-harness/src/client.rs @@ -76,9 +76,9 @@ impl HarnessClient { A: serde::Serialize, R: DeserializeOwned, { - let envelope = self.call_raw(function, args).await?; + let (status, envelope) = self.call_raw_with_status(function, args).await?; if !envelope.success { - return Err(envelope_to_error(envelope)); + return Err(envelope_to_error(envelope, status)); } let data = envelope.data.unwrap_or(serde_json::Value::Null); Ok(serde_json::from_value(data)?) 
@@ -89,6 +89,21 @@ impl HarnessClient { /// in `RpcEnvelope.error.code` and the response status code is folded in /// when the envelope is missing (e.g. middleware rejection). pub async fn call_raw(&self, function: &str, args: A) -> Result + where + A: serde::Serialize, + { + let (_status, envelope) = self.call_raw_with_status(function, args).await?; + Ok(envelope) + } + + /// Same as [`call_raw`] but also returns the HTTP status code. Used + /// internally so error envelopes can carry the real status (401 vs 403 vs + /// 500) instead of collapsing to 0. + pub async fn call_raw_with_status( + &self, + function: &str, + args: A, + ) -> Result<(u16, RpcEnvelope), HarnessError> where A: serde::Serialize, { @@ -112,17 +127,20 @@ impl HarnessClient { // Empty body with non-2xx status: synthesize an envelope so callers // get a uniform error type rather than a serde failure on `null`. if bytes.is_empty() && !status.is_success() { - return Ok(RpcEnvelope { - success: false, - data: None, - error: Some(RpcEnvelopeError { - code: format!("HTTP_{}", status.as_u16()), - message: status.canonical_reason().unwrap_or("unknown").to_string(), - retry_after_secs: None, - details: None, - }), - request_id: None, - }); + return Ok(( + status.as_u16(), + RpcEnvelope { + success: false, + data: None, + error: Some(RpcEnvelopeError { + code: format!("HTTP_{}", status.as_u16()), + message: status.canonical_reason().unwrap_or("unknown").to_string(), + retry_after_secs: None, + details: None, + }), + request_id: None, + }, + )); } let envelope: RpcEnvelope = @@ -135,7 +153,7 @@ impl HarnessClient { ), status: status.as_u16(), })?; - Ok(envelope) + Ok((status.as_u16(), envelope)) } /// Invoke an RPC and assert that it failed. 
Returns the error envelope @@ -148,7 +166,7 @@ impl HarnessClient { where A: serde::Serialize, { - let envelope = self.call_raw(function, args).await?; + let (status, envelope) = self.call_raw_with_status(function, args).await?; if envelope.success { return Err(HarnessError::Rpc { code: "UNEXPECTED_SUCCESS".to_string(), @@ -156,28 +174,28 @@ impl HarnessClient { "expected {function} to fail, got success: {:?}", envelope.data ), - status: 200, + status, }); } envelope.error.ok_or_else(|| HarnessError::Rpc { code: "MALFORMED_RESPONSE".to_string(), message: "error envelope without `error` field".to_string(), - status: 0, + status, }) } } -fn envelope_to_error(envelope: RpcEnvelope) -> HarnessError { +fn envelope_to_error(envelope: RpcEnvelope, status: u16) -> HarnessError { match envelope.error { Some(err) => HarnessError::Rpc { code: err.code, message: err.message, - status: 0, + status, }, None => HarnessError::Rpc { code: "MALFORMED_RESPONSE".to_string(), message: "success=false but no error".to_string(), - status: 0, + status, }, } } diff --git a/crates/forge-harness/src/error.rs b/crates/forge-harness/src/error.rs index 49d4b21f..25ae043c 100644 --- a/crates/forge-harness/src/error.rs +++ b/crates/forge-harness/src/error.rs @@ -17,7 +17,7 @@ pub enum HarnessError { #[error("sqlx error: {0}")] Sqlx(#[from] sqlx::Error), - #[error("rpc call failed: code={code} message={message}")] + #[error("rpc call failed: status={status} code={code} message={message}")] Rpc { code: String, message: String, diff --git a/crates/forge-harness/src/sse.rs b/crates/forge-harness/src/sse.rs index bc3cbee5..0c70a33e 100644 --- a/crates/forge-harness/src/sse.rs +++ b/crates/forge-harness/src/sse.rs @@ -1,3 +1,4 @@ +use std::collections::{HashMap, VecDeque}; use std::time::Duration; use eventsource_stream::Eventsource; @@ -44,6 +45,17 @@ pub struct HarnessSession { session_id: String, session_secret: String, events: Mutex, + /// Per-target backlog so events the test doesn't currently care 
about + /// aren't silently dropped. Keys are the wire-level targets ("sub:foo", + /// "job:abc", "wf:xyz"). A test that waits on "job:a" first and then + /// "wf:b" sees the "wf:b" push even if it arrived during the "job:a" wait. + buffered: Mutex>>, +} + +#[derive(Debug, Clone)] +enum BufferedEvent { + Update(serde_json::Value), + Error { code: String, message: String }, } type EventStream = @@ -57,17 +69,15 @@ impl HarnessSession { base_url: String, token: Option, ) -> Result { - let url = if let Some(t) = &token { - format!("{base_url}/_api/events?token={t}") - } else { - format!("{base_url}/_api/events") - }; - - let resp = http - .get(&url) - .header("Accept", "text/event-stream") - .send() - .await?; + // Auth via Authorization header rather than ?token=…, so a regression + // that closes the query-string loophole on the server side doesn't + // falsely fail every harness session test. + let url = format!("{base_url}/_api/events"); + let mut req = http.get(&url).header("Accept", "text/event-stream"); + if let Some(t) = &token { + req = req.bearer_auth(t); + } + let resp = req.send().await?; if !resp.status().is_success() { return Err(HarnessError::sse(format!( "SSE connect failed: status={}", @@ -107,6 +117,7 @@ impl HarnessSession { session_id, session_secret, events: Mutex::new(events), + buffered: Mutex::new(HashMap::new()), }) } @@ -247,10 +258,32 @@ impl HarnessSession { Ok(body.get("data").cloned().unwrap_or(serde_json::Value::Null)) } + /// Explicitly close the SSE stream, releasing the reqwest connection and + /// signaling the gateway to drop the session. + /// + /// Called automatically by `Drop`, but tests that open many sessions per + /// test can invoke this proactively to keep the gateway's session table + /// small. Idempotent — safe to call multiple times. + pub async fn close(&self) { + // Replacing the stream with an empty one drops the underlying reqwest + // body and the gateway sees the TCP connection close. 
We don't await + // the gateway's cleanup; the SessionServer evicts the row on disconnect. + let mut events = self.events.lock().await; + *events = Box::pin(futures_util::stream::empty()); + } + /// Read the next SSE event from the stream within the given budget. + /// + /// The lock is released between events rather than held for the full + /// timeout, so concurrent tasks sharing a session can make progress. pub async fn next_event(&self, within: Duration) -> Result { - let mut events = self.events.lock().await; - match timeout(within, events.next()).await { + // Lock just long enough to poll the stream once; releasing it + // between polls lets another task interleave. + let poll = { + let mut events = self.events.lock().await; + timeout(within, events.next()).await + }; + match poll { Ok(Some(Ok(ev))) => Ok(ev), Ok(Some(Err(e))) => Err(e), Ok(None) => Err(HarnessError::sse("SSE stream ended")), @@ -258,36 +291,65 @@ impl HarnessSession { } } - /// Read events until we see an `Update` for the given target. Other - /// events are buffered in the stream order is preserved on the next - /// `next_event` call (we drop them). Use this in tests that only care - /// about a specific subscription's payload. + /// Read events until we see an `Update` for the given target. Events for + /// other targets are buffered (per-target FIFO) so a subsequent call for + /// a different target still sees pushes that arrived during this wait. pub async fn next_update_for( &self, target: &str, within: Duration, ) -> Result { + // First, drain any buffered event for this target. 
+ if let Some(ev) = self.pop_buffered(target).await { + return match ev { + BufferedEvent::Update(p) => Ok(p), + BufferedEvent::Error { code, message } => Err(HarnessError::sse(format!( + "error for target {target}: {code} {message}" + ))), + }; + } + let deadline = tokio::time::Instant::now() + within; loop { let remaining = deadline .checked_duration_since(tokio::time::Instant::now()) .ok_or_else(|| HarnessError::timeout(format!("update for {target}")))?; match self.next_event(remaining).await? { - SseEvent::Update { target: t, payload } if t == target => return Ok(payload), + SseEvent::Update { target: t, payload } => { + if t == target { + return Ok(payload); + } + self.push_buffered(t, BufferedEvent::Update(payload)).await; + } SseEvent::Error { target: t, code, message, - } if t == target => { - return Err(HarnessError::sse(format!( - "error for target {t}: {code} {message}" - ))); + } => { + if t == target { + return Err(HarnessError::sse(format!( + "error for target {t}: {code} {message}" + ))); + } + self.push_buffered(t, BufferedEvent::Error { code, message }) + .await; } _ => continue, } } } + async fn pop_buffered(&self, target: &str) -> Option { + let mut buf = self.buffered.lock().await; + let q = buf.get_mut(target)?; + q.pop_front() + } + + async fn push_buffered(&self, target: String, ev: BufferedEvent) { + let mut buf = self.buffered.lock().await; + buf.entry(target).or_default().push_back(ev); + } + /// Wait for a reactor push to the query subscription `id` — the id passed /// to [`HarnessSession::subscribe`]. Hides the wire-level `sub:` target /// prefix the gateway adds to query updates. @@ -320,6 +382,20 @@ impl HarnessSession { } } +impl Drop for HarnessSession { + /// Best-effort close: replace the stream with an empty one so the + /// reqwest body and underlying TCP connection are dropped synchronously. + /// The gateway's SessionServer reaps the row on the next cleanup pass. 
+ fn drop(&mut self) { + // Drain a blocking try_lock if available; if a task still holds the + // events mutex, the stream will be dropped when that task releases + // it. We don't .await here — Drop is sync. + if let Ok(mut events) = self.events.try_lock() { + *events = Box::pin(futures_util::stream::empty()); + } + } +} + fn parse_sse_event(ev: &eventsource_stream::Event) -> SseEvent { match ev.event.as_str() { "connected" => { diff --git a/crates/forge-harness/tests/auth.rs b/crates/forge-harness/tests/auth.rs index b91d0921..d40ac20c 100644 --- a/crates/forge-harness/tests/auth.rs +++ b/crates/forge-harness/tests/auth.rs @@ -1,4 +1,3 @@ -#![cfg(feature = "testcontainers")] //! Authentication and authorization scenarios. //! //! Every frontend client is one of three callers: anonymous, an authenticated @@ -9,14 +8,32 @@ //! hit a 403 on the admin page": same gateway, same JWT verification, same //! `require_auth` path, same RPC envelope a browser client would consume. +/// Sentinel test so `cargo test -p forge-harness` (without `--features +/// testcontainers`) doesn't silently report "0 tests passed" and lull a +/// contributor into thinking they ran the scenario suite. Always passes; +/// its job is to print the hint. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness auth scenarios are gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers` \ + to exercise authentication paths against a real Postgres." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] mod common; +#[cfg(feature = "testcontainers")] use common::{Note, start_app}; +#[cfg(feature = "testcontainers")] use uuid::Uuid; /// A private query must reject an anonymous caller at the gateway — before the /// handler body runs — with the `UNAUTHORIZED` envelope a client turns into a /// 401. If this regressed to a success the whole private surface would leak. 
+#[cfg(feature = "testcontainers")] #[tokio::test] async fn private_query_rejects_anonymous_caller() { let app = start_app("auth_anon_query").await; @@ -37,6 +54,7 @@ async fn private_query_rejects_anonymous_caller() { /// The same gate applies to a private mutation: an anonymous caller is turned /// away before any write is attempted. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn private_mutation_rejects_anonymous_caller() { let app = start_app("auth_anon_mutation").await; @@ -63,6 +81,7 @@ async fn private_mutation_rejects_anonymous_caller() { /// stamps `owner_id` from the JWT subject, and the query filters by it — so a /// regression on either side (a dropped `WHERE`, a mis-read subject claim) /// surfaces as one user seeing the other's rows. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn notes_stay_isolated_between_users() { let app = start_app("auth_user_isolation").await; @@ -135,6 +154,7 @@ async fn notes_stay_isolated_between_users() { /// A role-gated handler must reject an authenticated caller who lacks the /// role with `FORBIDDEN` — distinct from the `UNAUTHORIZED` an anonymous /// caller gets. Authentication alone is not authorization. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn role_gated_query_rejects_missing_role() { let app = start_app("auth_role_missing").await; @@ -155,9 +175,92 @@ async fn role_gated_query_rejects_missing_role() { app.shutdown().await.expect("shutdown"); } +/// An expired JWT must be rejected with `UNAUTHORIZED` before the handler +/// runs. If this regressed to a success the gateway would honor any token +/// whose signature happens to verify, regardless of `exp`. +#[cfg(feature = "testcontainers")] +#[tokio::test] +async fn expired_token_is_rejected_as_unauthorized() { + let app = start_app("auth_expired_token").await; + + // duration_secs(-3600) issues a token whose `exp` is one hour in the past. 
+ let user_id = Uuid::new_v4(); + let expired = app + .issue_token_with_claims(|b| b.user_id(user_id).duration_secs(-3600)) + .expect("issue expired token"); + + let client = app.client().with_token(expired); + let error = client + .expect_error("harness_my_notes", ()) + .await + .expect("a call with an expired token must fail, not succeed"); + assert_eq!( + error.code, "UNAUTHORIZED", + "expired tokens must carry UNAUTHORIZED, saw: {error:?}", + ); + + app.shutdown().await.expect("shutdown"); +} + +/// A malformed bearer token (not even three dotted segments) must be rejected +/// at the gateway, not silently treated as anonymous. +#[cfg(feature = "testcontainers")] +#[tokio::test] +async fn malformed_token_is_rejected_as_unauthorized() { + let app = start_app("auth_malformed_token").await; + + let client = app.client().with_token("this-is-not-a-jwt"); + let error = client + .expect_error("harness_my_notes", ()) + .await + .expect("a call with a malformed token must fail, not succeed"); + assert_eq!( + error.code, "UNAUTHORIZED", + "malformed tokens must carry UNAUTHORIZED, saw: {error:?}", + ); + + app.shutdown().await.expect("shutdown"); +} + +/// A token signed with a different secret must be rejected. This is the +/// signature-verification path: a regression that disabled verify would let +/// any attacker mint tokens against any deployment. +#[cfg(feature = "testcontainers")] +#[tokio::test] +async fn wrong_secret_token_is_rejected_as_unauthorized() { + use forge_core::TokenIssuer; + use forge_runtime::gateway::HmacTokenIssuer; + let app = start_app("auth_wrong_secret").await; + + // Mint a token using an unrelated secret. Same shape as the harness's + // tokens; only the HMAC over header+payload differs. 
+ let attacker_secret = "totally-different-secret-not-used-by-the-harness-instance"; + let attacker_cfg = forge_runtime::gateway::AuthConfig::with_secret(attacker_secret.to_string()); + let attacker_issuer = HmacTokenIssuer::from_config(&attacker_cfg).expect("issuer"); + let claims = forge_core::Claims::builder() + .user_id(Uuid::new_v4()) + .duration_secs(3600) + .build() + .expect("claims"); + let forged = attacker_issuer.sign(&claims).expect("sign"); + + let client = app.client().with_token(forged); + let error = client + .expect_error("harness_my_notes", ()) + .await + .expect("a call with a wrong-secret token must fail, not succeed"); + assert_eq!( + error.code, "UNAUTHORIZED", + "wrong-secret tokens must carry UNAUTHORIZED, saw: {error:?}", + ); + + app.shutdown().await.expect("shutdown"); +} + /// The other half of the role gate: a caller holding the role passes, and the /// handler runs. The count reflects a note the same caller just created, which /// proves the request reached the body rather than short-circuiting. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn role_gated_query_admits_present_role() { let app = start_app("auth_role_present").await; diff --git a/crates/forge-harness/tests/common/mod.rs b/crates/forge-harness/tests/common/mod.rs index e2c5c921..0a17b6b7 100644 --- a/crates/forge-harness/tests/common/mod.rs +++ b/crates/forge-harness/tests/common/mod.rs @@ -222,7 +222,7 @@ pub async fn harness_get_counter(ctx: &QueryContext, name: String) -> Result Result { let mut conn = ctx.conn().await?; sqlx::query_as::<_, Counter>( @@ -239,7 +239,7 @@ pub async fn harness_bump_counter(ctx: &MutationContext, input: BumpInput) -> Re /// Insert a widget row. Touches `harness_widgets` only — used to prove a /// mutation on an unrelated table does NOT invalidate a counter subscription. 
-#[forge::mutation(auth = "none")] +#[forge::mutation(auth = "none", tables("harness_widgets"), scope = "global")] pub async fn harness_add_widget(ctx: &MutationContext, name: String) -> Result { let mut conn = ctx.conn().await?; let row: (Uuid,) = @@ -250,6 +250,26 @@ pub async fn harness_add_widget(ctx: &MutationContext, name: String) -> Result Result { + let mut conn = ctx.conn().await?; + let _: (Uuid,) = sqlx::query_as("INSERT INTO harness_widgets (name) VALUES ($1) RETURNING id") + .bind(&name) + .fetch_one(&mut conn) + .await?; + // Release the tx connection before dispatch_job re-acquires it. + drop(conn); + ctx.dispatch_job("harness_run_job", RunJobInput { steps: 1 }) + .await?; + Err(ForgeError::internal( + "harness_tx_rollback fails after dispatch by design — widget and job must roll back", + )) +} + /// RPC dispatch result for a job: the gateway returns `{"job_id": "..."}` /// when a job handler is invoked by name. #[derive(Debug, Clone, Deserialize)] @@ -296,6 +316,44 @@ pub async fn harness_failing_job(_ctx: &JobContext, input: RunJobInput) -> Resul ))) } +/// Output of [`harness_retry_job`]: proves the second attempt observed the +/// retry state rather than just "the handler ran twice". +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryJobOutput { + pub was_retry: bool, + pub attempt: u32, +} + +/// Fails its first attempt by design, succeeds on the retry. `max_attempts = 2` +/// with a fixed (~1s) backoff exercises retry-then-succeed: the attempt counter +/// must advance and `ctx.is_retry()` must be true on the second run. The audit +/// found this path had no coverage at any layer. 
+#[forge::job(auth = "none", retry(max_attempts = 2, backoff = "fixed"))] +pub async fn harness_retry_job(ctx: &JobContext, _input: RunJobInput) -> Result { + if !ctx.is_retry() { + return Err(ForgeError::internal( + "harness_retry_job fails the first attempt by design", + )); + } + Ok(RetryJobOutput { + was_retry: ctx.is_retry(), + attempt: ctx.attempt, + }) +} + +/// Always fails. With `max_attempts = 2` it fails attempt 1 (retry scheduled), +/// fails attempt 2 (attempts == max), and must land in `dead_letter` — proving +/// the dead-letter routing terminates rather than retrying forever. +#[forge::job(auth = "none", retry(max_attempts = 2, backoff = "fixed"))] +pub async fn harness_dead_letter_job( + _ctx: &JobContext, + _input: RunJobInput, +) -> Result { + Err(ForgeError::internal( + "harness_dead_letter_job always fails by design", + )) +} + /// RPC dispatch result for a workflow: the gateway returns /// `{"workflow_id": "..."}` when a workflow is invoked by name. #[derive(Debug, Clone, Deserialize)] @@ -382,7 +440,7 @@ pub async fn harness_gated(ctx: &WorkflowContext, input: PipelineInput) -> Resul /// Fire the `harness_gate_opened` event that unblocks a waiting /// [`harness_gated`] run. `workflow_id` is the run id the workflow's RPC /// dispatch returned; it is the event's correlation id. -#[forge::mutation(auth = "none")] +#[forge::mutation(auth = "none", tables("forge_workflow_events"), scope = "global")] pub async fn harness_open_gate(ctx: &MutationContext, workflow_id: String) -> Result { sqlx::query( "INSERT INTO forge_workflow_events (id, event_name, correlation_id, payload) @@ -409,7 +467,7 @@ pub struct Note { /// call form keeps that WHERE clause invisible to the macro's structural scope /// lint, so `scope = "global"` opts out — runtime isolation via `ctx.user_id()` /// is real and is exactly what the auth scenarios assert. 
-#[forge::query(scope = "global", tables("harness_notes"))] +#[forge::query(tables("harness_notes"), scope = "global")] pub async fn harness_my_notes(ctx: &QueryContext) -> Result> { let owner = ctx.user_id()?; sqlx::query_as::<_, Note>( @@ -426,7 +484,7 @@ pub async fn harness_my_notes(ctx: &QueryContext) -> Result> { /// A bare INSERT has no WHERE clause for the macro's structural scope lint to /// inspect, so `scope = "global"` opts out — the row is still scoped at /// runtime by stamping `owner_id` with `ctx.user_id()`. -#[forge::mutation(scope = "global", tables("harness_notes"))] +#[forge::mutation(tables("harness_notes"), scope = "global")] pub async fn harness_create_note(ctx: &MutationContext, body: String) -> Result { let owner = ctx.user_id()?; let mut conn = ctx.conn().await?; diff --git a/crates/forge-harness/tests/job_retry.rs b/crates/forge-harness/tests/job_retry.rs new file mode 100644 index 00000000..a30362eb --- /dev/null +++ b/crates/forge-harness/tests/job_retry.rs @@ -0,0 +1,153 @@ +//! Job retry + dead-letter scenarios. +//! +//! The effectiveness audit found these completely uncovered: the existing +//! failing-job test uses `max_attempts = 1`, so the retry counter, backoff +//! delay, `ctx.is_retry()` branch, and dead-letter routing never ran at any +//! layer. A wrong attempt comparison or an unapplied backoff would ship green. +//! These drive the real worker end-to-end against Postgres. + +// Asserting on `forge_jobs` uses runtime `sqlx::query_as` (no compile-time DB): +// the harness owns no .sqlx cache, same as `common/mod.rs`. Tests panic to fail. +#![allow(clippy::disallowed_methods)] + +/// Sentinel so `cargo test -p forge-harness` without the feature doesn't report +/// "0 tests passed" and lull a contributor into thinking the suite ran. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness job-retry scenarios are gated on `--features testcontainers`. 
\ + Re-run with `cargo test -p forge-harness --features testcontainers`." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] +mod common; + +#[cfg(feature = "testcontainers")] +mod scenarios { + use std::time::{Duration, Instant}; + + use uuid::Uuid; + + use super::common::{JobHandle, RetryJobOutput, RunJobInput, drain_job_updates, start_app}; + + /// Worker poll (50ms) + a ~1s fixed retry backoff + reactor round-trips. + const BUDGET: Duration = Duration::from_secs(15); + + /// A job that errors on its first attempt must be retried after the backoff + /// and succeed on the second run, with the attempt counter advanced. + #[tokio::test] + async fn job_retries_once_then_succeeds() { + let app = start_app("job_retry_succeeds").await; + let session = app.open_session(None).await.expect("open sse session"); + + let handle: JobHandle = app + .client() + .call("harness_retry_job", RunJobInput { steps: 0 }) + .await + .expect("dispatch retry job"); + + session + .subscribe_job("r", &handle.job_id) + .await + .expect("subscribe to job"); + + let started = Instant::now(); + let updates = drain_job_updates(&session, "r", BUDGET).await; + let terminal = updates.last().expect("drain yields the terminal update"); + + assert_eq!( + terminal.get("status").and_then(serde_json::Value::as_str), + Some("completed"), + "a job that fails once then succeeds must end completed, saw: {terminal}", + ); + // The fixed backoff is ~1s (±25% jitter → ≥0.75s). A 500ms floor cleanly + // separates "backoff applied" from an instant (~ms) re-run, with margin + // for jitter. This is what regressed when retry delays under 1s were + // truncated to 0 (queue.rs num_seconds bug). 
+ assert!( + started.elapsed() >= Duration::from_millis(500), + "retry must wait out the backoff, not re-run instantly ({:?})", + started.elapsed(), + ); + + let (status, attempts): (String, i32) = + sqlx::query_as("SELECT status, attempts FROM forge_jobs WHERE id = $1") + .bind(Uuid::parse_str(&handle.job_id).expect("job id is a uuid")) + .fetch_one(app.pool()) + .await + .expect("job row"); + assert_eq!(status, "completed"); + assert_eq!( + attempts, 2, + "the attempt counter must advance across the retry" + ); + + let output: RetryJobOutput = serde_json::from_value( + terminal + .get("output") + .cloned() + .expect("completed job carries output"), + ) + .expect("output deserializes to RetryJobOutput"); + assert!( + output.was_retry, + "the second run must observe ctx.is_retry()" + ); + assert_eq!(output.attempt, 2); + + app.shutdown().await.expect("shutdown"); + } + + /// A job that always fails must exhaust exactly `max_attempts` and land in + /// `dead_letter` — not retry forever, and not dead-letter prematurely. 
+ #[tokio::test] + async fn job_dead_letters_after_exhausting_attempts() { + let app = start_app("job_dead_letter").await; + let session = app.open_session(None).await.expect("open sse session"); + + let handle: JobHandle = app + .client() + .call("harness_dead_letter_job", RunJobInput { steps: 0 }) + .await + .expect("dispatch dead-letter job"); + + session + .subscribe_job("d", &handle.job_id) + .await + .expect("subscribe to job"); + + let updates = drain_job_updates(&session, "d", BUDGET).await; + let terminal = updates.last().expect("drain yields the terminal update"); + + assert_eq!( + terminal.get("status").and_then(serde_json::Value::as_str), + Some("dead_letter"), + "a job that always fails must dead-letter after max attempts, saw: {terminal}", + ); + + let (status, attempts): (String, i32) = + sqlx::query_as("SELECT status, attempts FROM forge_jobs WHERE id = $1") + .bind(Uuid::parse_str(&handle.job_id).expect("job id is a uuid")) + .fetch_one(app.pool()) + .await + .expect("job row"); + assert_eq!(status, "dead_letter"); + assert_eq!( + attempts, 2, + "must exhaust exactly max_attempts before dead-lettering" + ); + + let error = terminal + .get("error") + .and_then(serde_json::Value::as_str) + .unwrap_or(""); + assert!( + !error.is_empty(), + "a dead-lettered job must stream a non-empty error, saw: {terminal}", + ); + + app.shutdown().await.expect("shutdown"); + } +} diff --git a/crates/forge-harness/tests/jobs.rs b/crates/forge-harness/tests/jobs.rs index f9b5b9e6..ad878e1a 100644 --- a/crates/forge-harness/tests/jobs.rs +++ b/crates/forge-harness/tests/jobs.rs @@ -1,4 +1,3 @@ -#![cfg(feature = "testcontainers")] //! Background-job scenarios. //! //! A job dispatched over RPC must land on the worker, run to a terminal @@ -7,19 +6,38 @@ //! export, watch the progress bar fill, see it finish": same gateway, same //! worker, same `job:` SSE wire frames a browser client would consume. 
+/// Sentinel test so `cargo test -p forge-harness` (without `--features +/// testcontainers`) doesn't silently report "0 tests passed" and lull a +/// contributor into thinking they ran the scenario suite. Always passes; +/// its job is to print the hint. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness job scenarios are gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers` \ + to exercise the worker against a real Postgres." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] mod common; +#[cfg(feature = "testcontainers")] use std::time::Duration; +#[cfg(feature = "testcontainers")] use common::{JobHandle, RunJobInput, drain_job_updates, start_app}; /// Worker poll (50ms) + a few hundred ms of job work + reactor round-trips. /// Generous enough to absorb CI scheduling noise without masking a hang. +#[cfg(feature = "testcontainers")] const JOB_BUDGET: Duration = Duration::from_secs(10); /// The core loop: dispatch a job, subscribe, and watch the worker stream it /// from a non-terminal state through to `completed`, carrying the handler's /// output on the final frame. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn job_runs_and_streams_lifecycle() { let app = start_app("jobs_lifecycle").await; @@ -82,6 +100,7 @@ async fn job_runs_and_streams_lifecycle() { /// A handler that returns `Err` must drive the job to a terminal failure /// state and stream the error message — not silently vanish. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn failing_job_streams_terminal_failure() { let app = start_app("jobs_failure").await; @@ -125,6 +144,7 @@ async fn failing_job_streams_terminal_failure() { /// `subscribe-job` hands back the job's current state synchronously, at /// subscribe time — the equivalent of a store's first value. 
+#[cfg(feature = "testcontainers")] #[tokio::test] async fn subscribe_job_returns_initial_snapshot() { let app = start_app("jobs_snapshot").await; diff --git a/crates/forge-harness/tests/reactivity.rs b/crates/forge-harness/tests/reactivity.rs index 96ad6b00..3b0a5003 100644 --- a/crates/forge-harness/tests/reactivity.rs +++ b/crates/forge-harness/tests/reactivity.rs @@ -1,4 +1,3 @@ -#![cfg(feature = "testcontainers")] //! Reactivity scenarios. //! //! A write to a reactive table must fan out to every SSE subscriber whose @@ -7,23 +6,43 @@ //! browserless proxy for "open the app in two tabs, change a row, watch both //! update": same gateway, same reactor, same SSE wire format. +/// Sentinel test so `cargo test -p forge-harness` (without `--features +/// testcontainers`) doesn't silently report "0 tests passed" and lull a +/// contributor into thinking they ran the scenario suite. Always passes; +/// its job is to print the hint. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness reactivity scenarios are gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers` \ + to exercise the reactor against a real Postgres." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] mod common; +#[cfg(feature = "testcontainers")] use std::time::Duration; +#[cfg(feature = "testcontainers")] use common::{BumpInput, Counter, collect_updates, start_app}; /// Generous budget for one reactor round-trip: NOTIFY -> invalidate (<=200ms /// debounce) -> re-execute -> SSE push. Real latency is tens of ms; the slack /// only absorbs CI scheduling noise. +#[cfg(feature = "testcontainers")] const PUSH_BUDGET: Duration = Duration::from_secs(5); /// Window to watch for a push that must NOT happen. Comfortably past the /// 200ms max debounce, short enough to keep the suite quick. 
+#[cfg(feature = "testcontainers")] const SILENCE_WINDOW: Duration = Duration::from_millis(1200); /// A fresh subscription hands back the current rows synchronously, at /// subscribe time — the equivalent of a SvelteKit store's first value. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn subscribe_returns_initial_snapshot() { let app = start_app("reactivity_initial_snapshot").await; @@ -60,6 +79,7 @@ async fn subscribe_returns_initial_snapshot() { /// The core loop: subscribe, mutate over RPC, observe the reactor push the /// fresh result down the SSE stream. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn single_client_sees_invalidation() { let app = start_app("reactivity_single_client").await; @@ -106,6 +126,7 @@ async fn single_client_sees_invalidation() { /// Two independent SSE sessions subscribed to the same query collapse into a /// single reactor group (dedup by query + args + auth scope), yet one mutation /// still fans out to both. This is the "two browser tabs" scenario. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn two_clients_both_receive_invalidation() { let app = start_app("reactivity_two_clients").await; @@ -160,6 +181,7 @@ async fn two_clients_both_receive_invalidation() { /// A mutation to a different reactive table must not wake a counter /// subscription. To rule out a false pass from a dead stream, we then issue a /// real counter write and confirm that push *does* arrive. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn unrelated_table_mutation_does_not_push() { let app = start_app("reactivity_unrelated_table").await; @@ -218,6 +240,7 @@ async fn unrelated_table_mutation_does_not_push() { /// changes only one arg-set's result must push to that subscription alone: /// the table invalidation re-executes both, but hash comparison suppresses the /// push for the one whose result did not change. 
+#[cfg(feature = "testcontainers")] #[tokio::test] async fn subscriptions_isolated_by_args() { let app = start_app("reactivity_args_isolation").await; diff --git a/crates/forge-harness/tests/smoke.rs b/crates/forge-harness/tests/smoke.rs index a912457c..a17476c6 100644 --- a/crates/forge-harness/tests/smoke.rs +++ b/crates/forge-harness/tests/smoke.rs @@ -1,72 +1,88 @@ -#![cfg(feature = "testcontainers")] //! Smoke test: boots the harness, hits a real RPC over HTTP, asserts the //! envelope wire-shape. Mirrors what `forge-svelte/client.ts` puts on the //! wire so this proves the harness sees the same bytes a browser would. -use forge::prelude::*; -use forge_harness::HarnessApp; - -#[forge::query(auth = "none", tables("forge_jobs"))] -pub async fn harness_smoke_ping(_ctx: &QueryContext) -> Result { - Ok("pong".to_string()) +/// Sentinel test so `cargo test -p forge-harness` (without `--features +/// testcontainers`) doesn't silently report "0 tests passed" and lull a +/// contributor into thinking they ran the scenario suite. Always passes; +/// its job is to print the hint. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness scenario tests are gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers` \ + to exercise the gateway/worker/reactor against a real Postgres." 
+ ); } -#[forge::query(auth = "none", tables("forge_jobs"))] -pub async fn harness_smoke_echo(_ctx: &QueryContext, message: String) -> Result { - Ok(format!("echo:{message}")) -} +#[cfg(feature = "testcontainers")] +mod scenarios { -#[tokio::test] -async fn smoke_query_returns_value() { - let app = HarnessApp::start("harness_smoke_query_returns_value") - .await - .expect("harness boot"); + use forge::prelude::*; + use forge_harness::HarnessApp; - let result: String = app - .client() - .call("harness_smoke_ping", ()) - .await - .expect("rpc call"); - assert_eq!(result, "pong"); + #[forge::query(auth = "none", tables("forge_jobs"))] + pub async fn harness_smoke_ping(_ctx: &QueryContext) -> Result { + Ok("pong".to_string()) + } - app.shutdown().await.expect("shutdown"); -} + #[forge::query(auth = "none", tables("forge_jobs"))] + pub async fn harness_smoke_echo(_ctx: &QueryContext, message: String) -> Result { + Ok(format!("echo:{message}")) + } -#[tokio::test] -async fn smoke_query_args_round_trip() { - let app = HarnessApp::start("harness_smoke_query_args_round_trip") - .await - .expect("harness boot"); + #[tokio::test] + async fn smoke_query_returns_value() { + let app = HarnessApp::start("harness_smoke_query_returns_value") + .await + .expect("harness boot"); - let result: String = app - .client() - .call( - "harness_smoke_echo", - serde_json::json!({ "message": "hello" }), - ) - .await - .expect("rpc call"); - assert_eq!(result, "echo:hello"); + let result: String = app + .client() + .call("harness_smoke_ping", ()) + .await + .expect("rpc call"); + assert_eq!(result, "pong"); - app.shutdown().await.expect("shutdown"); -} + app.shutdown().await.expect("shutdown"); + } -#[tokio::test] -async fn smoke_unknown_function_returns_error_envelope() { - let app = HarnessApp::start("harness_smoke_unknown_function") - .await - .expect("harness boot"); + #[tokio::test] + async fn smoke_query_args_round_trip() { + let app = 
HarnessApp::start("harness_smoke_query_args_round_trip") + .await + .expect("harness boot"); - let err = app - .client() - .expect_error("harness_smoke_no_such_function", ()) - .await - .expect("expected error envelope"); - assert!( - !err.code.is_empty(), - "error envelope must carry a code, got: {:?}", - err - ); + let result: String = app + .client() + .call( + "harness_smoke_echo", + serde_json::json!({ "message": "hello" }), + ) + .await + .expect("rpc call"); + assert_eq!(result, "echo:hello"); + + app.shutdown().await.expect("shutdown"); + } + + #[tokio::test] + async fn smoke_unknown_function_returns_error_envelope() { + let app = HarnessApp::start("harness_smoke_unknown_function") + .await + .expect("harness boot"); + + let err = app + .client() + .expect_error("harness_smoke_no_such_function", ()) + .await + .expect("expected error envelope"); + assert!( + !err.code.is_empty(), + "error envelope must carry a code, got: {:?}", + err + ); - app.shutdown().await.expect("shutdown"); + app.shutdown().await.expect("shutdown"); + } } diff --git a/crates/forge-harness/tests/tx_rollback.rs b/crates/forge-harness/tests/tx_rollback.rs new file mode 100644 index 00000000..1864c919 --- /dev/null +++ b/crates/forge-harness/tests/tx_rollback.rs @@ -0,0 +1,68 @@ +//! Transactional integrity. +//! +//! A mutation that writes a row and dispatches a job, then errors, must leave +//! NOTHING behind: both the data row and the buffered job (dispatched on the +//! transaction via the outbox path) roll back together. The audit found the +//! harness only ever committed on success — a commit-on-error or +//! job-written-outside-the-tx regression would silently corrupt data and pass. + +// Asserting on harness_widgets / forge_jobs uses runtime sqlx (no compile-time +// DB); same rationale as common/mod.rs. Tests panic to fail. +#![allow(clippy::disallowed_methods)] + +/// Sentinel so the suite isn't silently empty without `--features testcontainers`. 
+#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness tx-rollback scenario is gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers`." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] +mod common; + +#[cfg(feature = "testcontainers")] +mod scenarios { + use super::common::start_app; + + /// The mutation errors after both the INSERT and the dispatch. Afterwards + /// neither the widget nor the job may exist — proving `execute_transactional` + /// rolls the whole unit back, outbox job included. + #[tokio::test] + async fn transactional_mutation_error_rolls_back_widget_and_job() { + let app = start_app("tx_rollback").await; + let name = "rollback-me-7a3f"; + + let err = app + .client() + .expect_error("harness_tx_rollback", name) + .await + .expect("mutation must fail so the transaction unwinds"); + assert!(!err.code.is_empty(), "error envelope must carry a code"); + + let widgets: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM harness_widgets WHERE name = $1") + .bind(name) + .fetch_one(app.pool()) + .await + .expect("count widgets"); + assert_eq!( + widgets.0, 0, + "the widget INSERT must roll back when the mutation errors", + ); + + let jobs: (i64,) = + sqlx::query_as("SELECT COUNT(*) FROM forge_jobs WHERE job_type = 'harness_run_job'") + .fetch_one(app.pool()) + .await + .expect("count jobs"); + assert_eq!( + jobs.0, 0, + "the dispatched job must roll back with the transaction (outbox-on-tx)", + ); + + app.shutdown().await.expect("shutdown"); + } +} diff --git a/crates/forge-harness/tests/workflows.rs b/crates/forge-harness/tests/workflows.rs index 713f887f..6b5fc88f 100644 --- a/crates/forge-harness/tests/workflows.rs +++ b/crates/forge-harness/tests/workflows.rs @@ -1,4 +1,3 @@ -#![cfg(feature = "testcontainers")] //! Workflow scenarios. //! //! A workflow dispatched over RPC must run on the worker, stream its state @@ -9,20 +8,39 @@ //! 
finish": same gateway, same executor, same scheduler, same `wf:` SSE wire //! frames a browser client would consume. +/// Sentinel test so `cargo test -p forge-harness` (without `--features +/// testcontainers`) doesn't silently report "0 tests passed" and lull a +/// contributor into thinking they ran the scenario suite. Always passes; +/// its job is to print the hint. +#[test] +fn ensure_testcontainers_feature_enabled() { + eprintln!( + "forge-harness workflow scenarios are gated on `--features testcontainers`. \ + Re-run with `cargo test -p forge-harness --features testcontainers` \ + to exercise the workflow executor against a real Postgres." + ); +} + +#[cfg(feature = "testcontainers")] +#[path = "common/mod.rs"] mod common; +#[cfg(feature = "testcontainers")] use std::time::Duration; +#[cfg(feature = "testcontainers")] use common::{PipelineInput, PipelineOutput, WorkflowHandle, await_workflow_status, start_app}; /// Worker poll (50ms) + workflow scheduler poll (100ms) + step persistence, a /// 400ms durable sleep, and reactor round-trips. Generous enough to absorb CI /// scheduling noise without masking a hang. +#[cfg(feature = "testcontainers")] const WF_BUDGET: Duration = Duration::from_secs(15); /// The straight-line case: dispatch a three-step workflow, subscribe, and watch /// the executor carry it to `completed` — with every step recorded, in /// declaration order, and the input label round-tripped onto the output. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn linear_workflow_runs_all_steps_to_completion() { let app = start_app("workflows_linear").await; @@ -100,6 +118,7 @@ async fn linear_workflow_runs_all_steps_to_completion() { /// The durable case: a workflow that blocks on `wait_for_event` must suspend in /// `waiting`, naming the event it needs — and stay there until something fires /// that event. Firing it from a separate RPC must resume the run to completion. 
+#[cfg(feature = "testcontainers")] #[tokio::test] async fn gated_workflow_blocks_on_event_then_resumes() { let app = start_app("workflows_gated").await; @@ -183,6 +202,7 @@ async fn gated_workflow_blocks_on_event_then_resumes() { /// `subscribe-workflow` hands back the workflow's current state synchronously, /// at subscribe time — the equivalent of a store's first value. +#[cfg(feature = "testcontainers")] #[tokio::test] async fn subscribe_workflow_returns_initial_snapshot() { let app = start_app("workflows_snapshot").await; diff --git a/crates/forge-macros/Cargo.toml b/crates/forge-macros/Cargo.toml index 328ea1dc..df76bd93 100644 --- a/crates/forge-macros/Cargo.toml +++ b/crates/forge-macros/Cargo.toml @@ -17,6 +17,8 @@ syn = { workspace = true } darling = { workspace = true } quote = { workspace = true } proc-macro2 = { workspace = true } +proc-macro-crate = { workspace = true } sqlparser = { workspace = true } cron = { workspace = true } +chrono-tz = { workspace = true } blake3 = { workspace = true } diff --git a/crates/forge-macros/src/cron.rs b/crates/forge-macros/src/cron.rs index 5a526c73..09913c8f 100644 --- a/crates/forge-macros/src/cron.rs +++ b/crates/forge-macros/src/cron.rs @@ -55,6 +55,7 @@ struct CronAttrs { } pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -145,26 +146,48 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let _vis = &input.vis; let block = &input.block; - let schedule = attrs.schedule.unwrap_or_else(|| "* * * * *".to_string()); + // Require explicit schedule — silent fallback to `* * * * *` (every + // minute) is almost always wrong. + let Some(raw_schedule) = attrs.schedule else { + return syn::Error::new_spanned( + &input.sig.ident, + "cron handlers require an explicit schedule. 
Use a positional cron \ + expression, `schedule = \"...\"`, `every = \"...\"`, or `daily_at = \"...\"`.", + ) + .to_compile_error() + .into(); + }; - // Normalize 5-part to 6-part (prepend seconds) to match what CronSchedule::new does. - { - let parts: Vec<&str> = schedule.split_whitespace().collect(); + // Normalize 5-part to 6-part (prepend seconds) to match what CronSchedule::new does, + // and pass the normalized form to the runtime so compile- and run-time agree. + let schedule = { + let parts: Vec<&str> = raw_schedule.split_whitespace().collect(); let normalized = if parts.len() == 5 { - format!("0 {schedule}") + format!("0 {raw_schedule}") } else { - schedule.clone() + raw_schedule.clone() }; if cron::Schedule::from_str(&normalized).is_err() { return syn::Error::new_spanned( &input.sig.ident, - format!("Invalid cron schedule: \"{schedule}\""), + format!("Invalid cron schedule: \"{raw_schedule}\""), ) .to_compile_error() .into(); } - } + normalized + }; let timezone = attrs.timezone.unwrap_or_else(|| "UTC".to_string()); + if timezone.parse::().is_err() { + return syn::Error::new_spanned( + &input.sig.ident, + format!( + "Invalid timezone: \"{timezone}\". Must be an IANA tz database name (e.g., \"UTC\", \"America/New_York\")." + ), + ) + .to_compile_error() + .into(); + } let group = attrs.group.unwrap_or_else(|| "default".to_string()); let catch_up = attrs.catch_up; let catch_up_limit = attrs.catch_up_limit.unwrap_or(10); @@ -185,7 +208,7 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.crons.register::<#struct_name>(); })); } @@ -202,15 +225,15 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::cron::ForgeCron for #struct_name { + impl #forge::forge_core::cron::ForgeCron for #struct_name { type Args = (); - fn info() -> forge::forge_core::cron::CronInfo { - forge::forge_core::cron::CronInfo { + fn info() -> #forge::forge_core::cron::CronInfo { + #forge::forge_core::cron::CronInfo { name: #rpc_name, - schedule: forge::forge_core::cron::CronSchedule::new_validated(#schedule), + schedule: #forge::forge_core::cron::CronSchedule::new_validated(#schedule), timezone: #timezone, group: #group, catch_up: #catch_up, @@ -221,8 +244,8 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::cron::CronContext, - ) -> std::pin::Pin> + Send + '_>> { + ctx: &#forge::forge_core::cron::CronContext, + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/daemon.rs b/crates/forge-macros/src/daemon.rs index 4dfd588f..665260a8 100644 --- a/crates/forge-macros/src/daemon.rs +++ b/crates/forge-macros/src/daemon.rs @@ -45,6 +45,7 @@ struct DaemonAttrs { } pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -108,7 +109,7 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.daemons.register::<#struct_name>(); })); } @@ -125,11 +126,11 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::daemon::ForgeDaemon for #struct_name { - fn info() -> forge::forge_core::daemon::DaemonInfo { - forge::forge_core::daemon::DaemonInfo { + impl #forge::forge_core::daemon::ForgeDaemon for #struct_name { + fn info() -> #forge::forge_core::daemon::DaemonInfo { + #forge::forge_core::daemon::DaemonInfo { name: #rpc_name, leader_elected: #leader_elected, restart_on_panic: #restart_on_panic, @@ -141,8 +142,8 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::daemon::DaemonContext, - ) -> std::pin::Pin> + Send + '_>> { + ctx: &#forge::forge_core::daemon::DaemonContext, + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/job.rs b/crates/forge-macros/src/job.rs index ce03f292..ffa8edd2 100644 --- a/crates/forge-macros/src/job.rs +++ b/crates/forge-macros/src/job.rs @@ -129,6 +129,7 @@ struct JobAttrs { } pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -167,7 +168,18 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let mut args_type = quote! 
{ () }; let mut args_ident = format_ident!("_args"); - for input_arg in input.sig.inputs.iter().skip(1) { + let user_args: Vec<_> = input.sig.inputs.iter().skip(1).collect(); + if user_args.len() > 1 { + return TokenStream::from( + syn::Error::new_spanned( + user_args[1], + "jobs may take at most one user argument (besides the JobContext). \ + Wrap multiple values in a single struct that derives Serialize/Deserialize.", + ) + .into_compile_error(), + ); + } + for input_arg in user_args { if let syn::FnArg::Typed(pat_type) = input_arg { if let syn::Pat::Ident(ident) = pat_type.pat.as_ref() { args_ident = ident.ident.clone(); @@ -224,27 +236,27 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let priority = if let Some(ref p) = attrs.priority { let p_lower = p.to_lowercase(); match p_lower.as_str() { - "background" => quote! { forge::forge_core::job::JobPriority::Background }, - "low" => quote! { forge::forge_core::job::JobPriority::Low }, - "normal" => quote! { forge::forge_core::job::JobPriority::Normal }, - "high" => quote! { forge::forge_core::job::JobPriority::High }, - "critical" => quote! { forge::forge_core::job::JobPriority::Critical }, - _ => quote! { forge::forge_core::job::JobPriority::Normal }, + "background" => quote! { #forge::forge_core::job::JobPriority::Background }, + "low" => quote! { #forge::forge_core::job::JobPriority::Low }, + "normal" => quote! { #forge::forge_core::job::JobPriority::Normal }, + "high" => quote! { #forge::forge_core::job::JobPriority::High }, + "critical" => quote! { #forge::forge_core::job::JobPriority::Critical }, + _ => quote! { #forge::forge_core::job::JobPriority::Normal }, } } else { - quote! { forge::forge_core::job::JobPriority::Normal } + quote! { #forge::forge_core::job::JobPriority::Normal } }; let max_attempts = attrs.max_attempts.unwrap_or(3); let backoff = if let Some(ref b) = attrs.backoff { match b.as_str() { - "fixed" => quote! 
{ forge::forge_core::job::BackoffStrategy::Fixed }, - "linear" => quote! { forge::forge_core::job::BackoffStrategy::Linear }, - "exponential" => quote! { forge::forge_core::job::BackoffStrategy::Exponential }, - _ => quote! { forge::forge_core::job::BackoffStrategy::Exponential }, + "fixed" => quote! { #forge::forge_core::job::BackoffStrategy::Fixed }, + "linear" => quote! { #forge::forge_core::job::BackoffStrategy::Linear }, + "exponential" => quote! { #forge::forge_core::job::BackoffStrategy::Exponential }, + _ => quote! { #forge::forge_core::job::BackoffStrategy::Exponential }, } } else { - quote! { forge::forge_core::job::BackoffStrategy::Exponential } + quote! { #forge::forge_core::job::BackoffStrategy::Exponential } }; let max_backoff = if let Some(ref mb) = attrs.max_backoff { @@ -285,10 +297,10 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let compensation_args_ident = format_ident!("_comp_args"); quote! { fn compensate( - ctx: &forge::forge_core::job::JobContext, + ctx: &#forge::forge_core::job::JobContext, #compensation_args_ident: Self::Args, reason: &str, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #handler_ident(ctx, #compensation_args_ident, reason).await }) } } @@ -300,7 +312,7 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.jobs.register::<#struct_name>(); })); } @@ -317,20 +329,20 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::job::ForgeJob for #struct_name { + impl #forge::forge_core::job::ForgeJob for #struct_name { type Args = #args_type; type Output = #output_type; - fn info() -> forge::forge_core::job::JobInfo { - forge::forge_core::job::JobInfo { + fn info() -> #forge::forge_core::job::JobInfo { + #forge::forge_core::job::JobInfo { name: #fn_name_str, description: #description_tokens, timeout: #timeout, http_timeout: #http_timeout, priority: #priority, - retry: forge::forge_core::job::RetryConfig { + retry: #forge::forge_core::job::RetryConfig { max_attempts: #max_attempts, backoff: #backoff, max_backoff: #max_backoff, @@ -346,9 +358,9 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::job::JobContext, + ctx: &#forge::forge_core::job::JobContext, #args_ident: Self::Args, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } diff --git a/crates/forge-macros/src/mcp_tool.rs b/crates/forge-macros/src/mcp_tool.rs index 8e02ffb2..d14dc03e 100644 --- a/crates/forge-macros/src/mcp_tool.rs +++ b/crates/forge-macros/src/mcp_tool.rs @@ -84,9 +84,23 @@ struct McpToolAttrs { } fn convert_mcp_tool_attrs(darling: DarlingMcpToolAttrs) -> Result { - let timeout = darling - .timeout - .and_then(|s| parse_duration_secs(&s).or_else(|| s.parse::().ok())); + // Require a unit suffix on timeouts to match every other macro. Bare + // integers like `timeout = "30"` are ambiguous (seconds? milliseconds?) 
+ // and were only accepted here historically. + let timeout = match darling.timeout { + Some(ref s) => match parse_duration_secs(s) { + Some(t) => Some(t), + None => { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!( + "invalid timeout \"{s}\": use a duration string like \"30s\", \"5m\", or \"1h\"" + ), + )); + } + }, + None => None, + }; let (rate_limit_requests, rate_limit_per_secs, rate_limit_key) = if let Some(ref rl) = darling.rate_limit { @@ -159,6 +173,12 @@ fn tool_type_stem(fn_name: &str) -> &str { } fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result { + let forge = crate::utils::forge_path(); + // schemars(crate = "...") needs a literal string. Build it from the + // resolved forge prefix at expansion time so a renamed dep still emits + // a working path. Tokens like `::forge` and `::forgex` render with the + // leading colons, which schemars accepts. + let schemars_crate_str = format!("{}::forge_core::schemars", forge); let fn_name = &input.sig.ident; let fn_name_str = attrs.name.unwrap_or_else(|| fn_name.to_string()); validate_tool_name(&fn_name_str)?; @@ -212,8 +232,7 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result = params.iter().skip(1).cloned().collect(); @@ -359,8 +378,8 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result syn::Result syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! 
{ #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! { - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.mcp_tools.register::<#struct_name>(); })); } @@ -433,14 +452,14 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result forge::forge_core::McpToolInfo { - forge::forge_core::McpToolInfo { + fn info() -> #forge::forge_core::McpToolInfo { + #forge::forge_core::McpToolInfo { name: #fn_name_str, title: #title, description: #description, @@ -450,7 +469,7 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/model.rs b/crates/forge-macros/src/model.rs index 0737c0c8..5eac3e83 100644 --- a/crates/forge-macros/src/model.rs +++ b/crates/forge-macros/src/model.rs @@ -1,7 +1,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, spanned::Spanned}; +use syn::{Data, DeriveInput, Fields, parse_macro_input, spanned::Spanned}; pub fn expand_model(attr: TokenStream, item: TokenStream) -> TokenStream { let input_clone = item.clone(); @@ -18,6 +18,7 @@ fn expand_model_impl( input: DeriveInput, _original_tokens: TokenStream2, ) -> syn::Result { + let forge = crate::utils::forge_path(); let attr_str = 
attr.to_string(); let trimmed = attr_str.trim(); if !trimmed.is_empty() { @@ -55,8 +56,8 @@ fn expand_model_impl( quote! { { - let rust_type = forge::forge_core::schema::RustType::from_type_string(#type_str); - let mut field = forge::forge_core::schema::FieldDef::new(#name, rust_type); + let rust_type = #forge::forge_core::schema::RustType::from_type_string(#type_str); + let mut field = #forge::forge_core::schema::FieldDef::new(#name, rust_type); field.column_name = #column_name.to_string(); field } @@ -89,11 +90,11 @@ fn expand_model_impl( #(#field_defs),* } - impl forge::forge_core::schema::ModelMeta for #struct_name { + impl #forge::forge_core::schema::ModelMeta for #struct_name { const TABLE_NAME: &'static str = #table_name; - fn table_def() -> forge::forge_core::schema::TableDef { - let mut table = forge::forge_core::schema::TableDef::new(#table_name, stringify!(#struct_name)); + fn table_def() -> #forge::forge_core::schema::TableDef { + let mut table = #forge::forge_core::schema::TableDef::new(#table_name, stringify!(#struct_name)); table.fields = vec![ #(#field_tokens),* ]; @@ -110,18 +111,23 @@ fn expand_model_impl( } fn get_table_name(input: &DeriveInput) -> syn::Result { - // Look for #[table(name = "...")] + // Look for #[table(name = "...")]. parse_nested_meta correctly handles + // escaped quotes and inner whitespace, unlike the previous string-slice + // parser which choked on `name = "with \"escape\""` and similar. 
for attr in &input.attrs { if attr.path().is_ident("table") { - let meta = attr.meta.clone(); - if let Meta::List(list) = meta { - let tokens: TokenStream2 = list.tokens; - let tokens_str = tokens.to_string(); - if tokens_str.starts_with("name") - && let Some(value) = extract_string_value(&tokens_str) - { - return Ok(value); + let mut found: Option = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + let lit: syn::LitStr = meta.value()?.parse()?; + found = Some(lit.value()); + Ok(()) + } else { + Err(meta.error("expected `name = \"...\"`")) } + })?; + if let Some(name) = found { + return Ok(name); } } } @@ -131,18 +137,6 @@ fn get_table_name(input: &DeriveInput) -> syn::Result { Ok(pluralize(&name)) } -fn extract_string_value(s: &str) -> Option { - // Parse "name = \"value\"" pattern - let parts: Vec<&str> = s.splitn(2, '=').collect(); - if parts.len() == 2 { - let value = parts[1].trim(); - if let Some(stripped) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) { - return Some(stripped.to_string()); - } - } - None -} - use crate::utils::{pluralize, to_snake_case}; #[cfg(test)] @@ -197,24 +191,6 @@ mod tests { assert_eq!(pluralize("buy"), "buys"); } - #[test] - fn extract_string_value_valid() { - assert_eq!( - extract_string_value(r#"name = "custom_table""#), - Some("custom_table".to_string()) - ); - } - - #[test] - fn extract_string_value_no_quotes() { - assert_eq!(extract_string_value("name = bare_value"), None); - } - - #[test] - fn extract_string_value_no_equals() { - assert_eq!(extract_string_value(r#""just a string""#), None); - } - // --- Table name derivation (integration of to_snake_case + pluralize) --- #[test] diff --git a/crates/forge-macros/src/mutation.rs b/crates/forge-macros/src/mutation.rs index bd90cfd5..5e0a8e29 100644 --- a/crates/forge-macros/src/mutation.rs +++ b/crates/forge-macros/src/mutation.rs @@ -215,6 +215,7 @@ fn convert_mutation_attrs(darling: DarlingMutationAttrs) -> Result syn::Result { + let forge 
= crate::utils::forge_path(); let fn_name = &input.sig.ident; let fn_name_str = fn_name.to_string(); let rpc_name = attrs.name.as_deref().unwrap_or(&fn_name_str).to_string(); @@ -250,53 +251,27 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result, found: bool, } - impl DispatchCallVisitor { - fn receiver_root_ident(mut expr: &syn::Expr) -> Option<&syn::Ident> { - loop { - match expr { - syn::Expr::MethodCall(inner) => expr = &inner.receiver, - syn::Expr::Try(inner) => expr = &inner.expr, - syn::Expr::Await(inner) => expr = &inner.base, - syn::Expr::Paren(inner) => expr = &inner.expr, - syn::Expr::Reference(inner) => expr = &inner.expr, - syn::Expr::Path(path) => { - if path.qself.is_none() && path.path.segments.len() == 1 { - return path.path.segments.first().map(|s| &s.ident); - } - return None; - } - _ => return None, - } - } - } - - fn receiver_is_ctx(&self, receiver: &syn::Expr) -> bool { - let Some(ref ctx) = self.ctx_ident else { - return true; - }; - Self::receiver_root_ident(receiver).is_some_and(|root| root == ctx) - } - } impl<'ast> syn::visit::Visit<'ast> for DispatchCallVisitor { fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) { let method = node.method.to_string(); - if (method == "dispatch_job" || method == "start_workflow") - && self.receiver_is_ctx(&node.receiver) - { + if method == "dispatch_job" || method == "start_workflow" { self.found = true; } syn::visit::visit_expr_method_call(self, node); } } - let mut visitor = DispatchCallVisitor { - ctx_ident: mutation_ctx_ident.clone(), - found: false, - }; + let mut visitor = DispatchCallVisitor { found: false }; syn::visit::visit_block(&mut visitor, fn_block); visitor.found }; @@ -323,17 +298,50 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result, found: bool, } + impl HttpCallVisitor { + fn receiver_root_ident(mut expr: &syn::Expr) -> Option<&syn::Ident> { + loop { + match expr { + syn::Expr::MethodCall(inner) => expr = 
&inner.receiver, + syn::Expr::Try(inner) => expr = &inner.expr, + syn::Expr::Await(inner) => expr = &inner.base, + syn::Expr::Paren(inner) => expr = &inner.expr, + syn::Expr::Reference(inner) => expr = &inner.expr, + syn::Expr::Path(path) => { + if path.qself.is_none() && path.path.segments.len() == 1 { + return path.path.segments.first().map(|s| &s.ident); + } + return None; + } + _ => return None, + } + } + } + fn receiver_is_ctx(&self, receiver: &syn::Expr) -> bool { + let Some(ref ctx) = self.ctx_ident else { + return true; + }; + Self::receiver_root_ident(receiver).is_some_and(|root| root == ctx) + } + } impl<'ast> syn::visit::Visit<'ast> for HttpCallVisitor { fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) { - if node.method == "http" { + // Gate on the receiver root resolving to the mutation context + // binding. Without this, any builder method named `.http()` + // on an unrelated type would trip the lint. + if node.method == "http" && self.receiver_is_ctx(&node.receiver) { self.found = true; } syn::visit::visit_expr_method_call(self, node); } } - let mut visitor = HttpCallVisitor { found: false }; + let mut visitor = HttpCallVisitor { + ctx_ident: mutation_ctx_ident.clone(), + found: false, + }; syn::visit::visit_block(&mut visitor, fn_block); if visitor.found { return Err(syn::Error::new_spanned( @@ -388,8 +396,7 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result = params.iter().skip(1).cloned().collect(); @@ -479,16 +486,18 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result { let key_tokens = match k.as_str() { - "user" => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, - "ip" => quote! { forge::forge_core::rate_limit::RateLimitKey::Ip }, - "tenant" => quote! { forge::forge_core::rate_limit::RateLimitKey::Tenant }, - "user_action" => quote! { forge::forge_core::rate_limit::RateLimitKey::UserAction }, - "global" => quote! 
{ forge::forge_core::rate_limit::RateLimitKey::Global }, + "user" => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, + "ip" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Ip }, + "tenant" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Tenant }, + "user_action" => { + quote! { #forge::forge_core::rate_limit::RateLimitKey::UserAction } + } + "global" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Global }, _ if k.starts_with("custom:") => { let claim = k.trim_start_matches("custom:"); - quote! { forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } + quote! { #forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } } - _ => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, + _ => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, }; quote! { Some(#key_tokens) } } @@ -498,13 +507,13 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result { let level_tokens = match l.as_str() { - "trace" => quote! { forge::forge_core::LogLevel::Trace }, - "debug" => quote! { forge::forge_core::LogLevel::Debug }, - "info" => quote! { forge::forge_core::LogLevel::Info }, - "warn" => quote! { forge::forge_core::LogLevel::Warn }, - "error" => quote! { forge::forge_core::LogLevel::Error }, - "off" => quote! { forge::forge_core::LogLevel::Off }, - _ => quote! { forge::forge_core::LogLevel::Trace }, + "trace" => quote! { #forge::forge_core::LogLevel::Trace }, + "debug" => quote! { #forge::forge_core::LogLevel::Debug }, + "info" => quote! { #forge::forge_core::LogLevel::Info }, + "warn" => quote! { #forge::forge_core::LogLevel::Warn }, + "error" => quote! { #forge::forge_core::LogLevel::Error }, + "off" => quote! { #forge::forge_core::LogLevel::Off }, + _ => quote! { #forge::forge_core::LogLevel::Trace }, }; quote! 
{ Some(#level_tokens) } } @@ -523,6 +532,12 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result = if let Some(ref tables) = attrs.tables { tables.clone() } else { + if let Some(issue) = extractor.issues.first() { + return Err(syn::Error::new_spanned( + &input.sig.ident, + issue.describe(&fn_name_str, "mutation"), + )); + } match extract_tables_from_sql(&extractor.sql_strings) { TableExtractionResult::Ok(tables) => { let mut sorted: Vec = tables.into_iter().collect(); @@ -671,29 +686,29 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.functions.register_mutation::<#struct_name>(); })); } @@ -711,17 +726,17 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result forge::forge_core::FunctionInfo { - forge::forge_core::FunctionInfo { + fn info() -> #forge::forge_core::FunctionInfo { + #forge::forge_core::FunctionInfo { name: #rpc_name, description: #description, - kind: forge::forge_core::FunctionKind::Mutation, + kind: #forge::forge_core::FunctionKind::Mutation, required_role: #required_role, is_public: #is_public, cache_ttl: None, @@ -742,9 +757,9 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/query.rs b/crates/forge-macros/src/query.rs index d6988c0a..a6503c91 100644 --- a/crates/forge-macros/src/query.rs +++ b/crates/forge-macros/src/query.rs @@ -209,6 +209,7 @@ fn convert_query_attrs(darling: DarlingQueryAttrs) -> Result syn::Result { + let forge = crate::utils::forge_path(); let fn_name = &input.sig.ident; let fn_name_str = fn_name.to_string(); let rpc_name = attrs.name.as_deref().unwrap_or(&fn_name_str).to_string(); @@ -258,8 +259,7 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result = if let Some(explicit_tables) = attrs.tables { @@ -268,6 +268,16 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let mut sorted: Vec = tables.into_iter().collect(); @@ -458,16 +468,18 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let key_tokens = match k.as_str() { - "user" => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, - "ip" => quote! { forge::forge_core::rate_limit::RateLimitKey::Ip }, - "tenant" => quote! 
{ forge::forge_core::rate_limit::RateLimitKey::Tenant }, - "user_action" => quote! { forge::forge_core::rate_limit::RateLimitKey::UserAction }, - "global" => quote! { forge::forge_core::rate_limit::RateLimitKey::Global }, + "user" => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, + "ip" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Ip }, + "tenant" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Tenant }, + "user_action" => { + quote! { #forge::forge_core::rate_limit::RateLimitKey::UserAction } + } + "global" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Global }, _ if k.starts_with("custom:") => { let claim = k.trim_start_matches("custom:"); - quote! { forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } + quote! { #forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } } - _ => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, + _ => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, }; quote! { Some(#key_tokens) } } @@ -477,13 +489,13 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let level_tokens = match l.as_str() { - "trace" => quote! { forge::forge_core::LogLevel::Trace }, - "debug" => quote! { forge::forge_core::LogLevel::Debug }, - "info" => quote! { forge::forge_core::LogLevel::Info }, - "warn" => quote! { forge::forge_core::LogLevel::Warn }, - "error" => quote! { forge::forge_core::LogLevel::Error }, - "off" => quote! { forge::forge_core::LogLevel::Off }, - _ => quote! { forge::forge_core::LogLevel::Trace }, + "trace" => quote! { #forge::forge_core::LogLevel::Trace }, + "debug" => quote! { #forge::forge_core::LogLevel::Debug }, + "info" => quote! { #forge::forge_core::LogLevel::Info }, + "warn" => quote! { #forge::forge_core::LogLevel::Warn }, + "error" => quote! { #forge::forge_core::LogLevel::Error }, + "off" => quote! { #forge::forge_core::LogLevel::Off }, + _ => quote! 
{ #forge::forge_core::LogLevel::Trace }, }; quote! { Some(#level_tokens) } } @@ -551,29 +563,29 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.functions.register_query::<#struct_name>(); })); } @@ -591,17 +603,17 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result forge::forge_core::FunctionInfo { - forge::forge_core::FunctionInfo { + fn info() -> #forge::forge_core::FunctionInfo { + #forge::forge_core::FunctionInfo { name: #rpc_name, description: #description, - kind: forge::forge_core::FunctionKind::Query, + kind: #forge::forge_core::FunctionKind::Query, required_role: #required_role, is_public: #is_public, cache_ttl: #cache_ttl, @@ -622,9 +634,9 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/sql_extractor.rs b/crates/forge-macros/src/sql_extractor.rs index 2a8436f3..fedb115a 100644 --- a/crates/forge-macros/src/sql_extractor.rs +++ b/crates/forge-macros/src/sql_extractor.rs @@ -11,6 +11,53 @@ use sqlparser::parser::Parser; use syn::visit::Visit; use syn::{Expr as SynExpr, ExprCall, ExprLit, ExprMacro, ExprMethodCall}; +/// Reasons that SQL extraction can't reason about the call site at all. +/// Surfaced to callers so they can emit a clear compile error directing the +/// user to set explicit `tables(...)`. +#[derive(Debug, Clone)] +pub enum SqlAnalysisIssue { + /// `sqlx::query(&some_string)` or `sqlx::query_as::<_, T>(...)` — runtime + /// variant that bypasses compile-time checking entirely. + RuntimeSqlxCall, + /// Inside a sqlx::query!{} macro, the SQL is built via `format!`, + /// `String::from`, `concat!`, or other non-literal construction. + DynamicSqlInMacro, + /// SQL string is hoisted into a `const`/`let` binding or `include_str!`, + /// so the macro can't see the literal at the call site. 
+ HoistedSqlBinding, + /// `sqlx::query!{}` received a byte-string literal which would otherwise + /// be silently dropped. + ByteStringInMacro, +} + +impl SqlAnalysisIssue { + pub fn describe(&self, fn_name: &str, macro_kind: &str) -> String { + let header = match self { + Self::RuntimeSqlxCall => format!( + "`{fn_name}` calls runtime `sqlx::query()`/`sqlx::query_as::<_, T>()`. \ + Use the `sqlx::query!` / `sqlx::query_as!` macros for compile-time checks." + ), + Self::DynamicSqlInMacro => format!( + "`{fn_name}` builds SQL dynamically (e.g. `format!`, `String::from`, `concat!`) \ + inside a `sqlx::query!` macro. Table dependencies and the scope lint cannot be \ + verified." + ), + Self::HoistedSqlBinding => format!( + "`{fn_name}` references SQL via `const`, `let`, or `include_str!` inside a \ + `sqlx::query!` macro. The literal is invisible to the extractor." + ), + Self::ByteStringInMacro => format!( + "`{fn_name}` passes a byte-string literal to a `sqlx::query!` macro. \ + SQL must be a regular string literal." + ), + }; + format!( + "{header}\n\ + Add #[{macro_kind}(tables(\"your_table\"))] to declare table dependencies explicitly." + ) + } +} + /// Detects `.pool()` calls in a handler body, signalling DB work delegated /// to a helper function whose SQL is invisible to `SqlStringExtractor`. pub struct DbDelegationDetector { @@ -35,12 +82,17 @@ impl<'ast> Visit<'ast> for DbDelegationDetector { /// Visitor that extracts SQL string literals from function bodies. pub struct SqlStringExtractor { pub sql_strings: Vec, + /// Patterns that defeat static SQL analysis. Callers should treat any + /// non-empty list as a hard compile error unless explicit `tables(...)` + /// was provided. 
+ pub issues: Vec, } impl SqlStringExtractor { pub fn new() -> Self { Self { sql_strings: Vec::new(), + issues: Vec::new(), } } @@ -83,6 +135,13 @@ impl SqlStringExtractor { match token { proc_macro2::TokenTree::Literal(lit) => { let lit_str = lit.to_string(); + // Reject byte strings outright — they parse as syn::LitStr + // failures and would otherwise be silently dropped. + let trimmed = lit_str.trim_start(); + if trimmed.starts_with("b\"") || trimmed.starts_with("br") { + self.issues.push(SqlAnalysisIssue::ByteStringInMacro); + continue; + } if let Some(sql) = Self::extract_string_content(&lit_str) && Self::looks_like_sql(&sql) { @@ -97,6 +156,75 @@ impl SqlStringExtractor { } } + /// Inspect the first token-stream argument passed to a `sqlx::query!` + /// macro and decide whether the SQL is recoverable as a literal. Flags + /// `format!(...)`, `concat!(...)`, `String::from(...)`, `include_str!`, + /// and bare identifier references (hoisted into `const SQL` or `let sql`). + fn check_macro_first_arg(&mut self, tokens: &proc_macro2::TokenStream) { + // Peek at the leading token sequence before the first `,` separator. + let mut head: Vec = Vec::new(); + for tt in tokens.clone() { + if let proc_macro2::TokenTree::Punct(ref p) = tt + && p.as_char() == ',' + { + break; + } + head.push(tt); + } + + // Strip leading `&` references — `sqlx::query!(&sql, ...)` is the + // same shape from our perspective. + let mut idx = 0; + while let Some(proc_macro2::TokenTree::Punct(p)) = head.get(idx) { + if p.as_char() == '&' { + idx += 1; + } else { + break; + } + } + let head = &head[idx..]; + + match head { + // Single string literal — handled by extract_sql_from_tokens. + [proc_macro2::TokenTree::Literal(_)] => {} + // Bare identifier: `query!(SQL)` or `query!(my_sql)` — hoisted. 
+ [proc_macro2::TokenTree::Ident(_)] => { + self.issues.push(SqlAnalysisIssue::HoistedSqlBinding); + } + // `format!(...)`, `concat!(...)`, `include_str!(...)`, or a + // path-qualified call like `String::from(...)`. Detect by an + // ident followed by `!` or `(` / `::`. + _ if head.len() >= 2 => { + if let proc_macro2::TokenTree::Ident(first) = &head[0] { + let name = first.to_string(); + let next = &head[1]; + let is_macro_call = + matches!(next, proc_macro2::TokenTree::Punct(p) if p.as_char() == '!'); + let is_path = + matches!(next, proc_macro2::TokenTree::Punct(p) if p.as_char() == ':'); + let is_call = matches!(next, proc_macro2::TokenTree::Group(_)); + if is_macro_call + && matches!( + name.as_str(), + "format" | "concat" | "include_str" | "format_args" + ) + { + if name == "include_str" { + self.issues.push(SqlAnalysisIssue::HoistedSqlBinding); + } else { + self.issues.push(SqlAnalysisIssue::DynamicSqlInMacro); + } + } else if is_path || is_call { + // `String::from(...)`, `format!`, or general fn call — + // treat as dynamic. + self.issues.push(SqlAnalysisIssue::DynamicSqlInMacro); + } + } + } + _ => {} + } + } + /// Extract the actual string content from a literal representation. /// Delegates parsing and unescaping to syn so raw, byte, and escaped /// forms all decode through the same canonical path. @@ -130,20 +258,25 @@ impl<'ast> Visit<'ast> for SqlStringExtractor { fn visit_expr_call(&mut self, node: &'ast ExprCall) { if let SynExpr::Path(path) = &*node.func { - let path_str = path + let last = path .path .segments - .iter() + .last() .map(|s| s.ident.to_string()) - .collect::>() - .join("::"); + .unwrap_or_default(); - if (path_str.contains("query") - || path_str.ends_with("query_as") - || path_str.ends_with("raw_sql")) - && let Some(first_arg) = node.args.first() - { - self.visit_expr(first_arg); + // Runtime sqlx calls (no compile-time checks): `sqlx::query(...)`, + // `sqlx::query_as::<_, T>(...)`, `sqlx::query_scalar(...)`, etc. 
+ // Match on the final path segment exactly — `query_helper` or + // `my_query` do not count. + if matches!( + last.as_str(), + "query" | "query_as" | "query_scalar" | "query_with" | "raw_sql" + ) { + self.issues.push(SqlAnalysisIssue::RuntimeSqlxCall); + if let Some(first_arg) = node.args.first() { + self.visit_expr(first_arg); + } } } @@ -171,13 +304,49 @@ impl<'ast> Visit<'ast> for SqlStringExtractor { macro_name.as_str(), "query" | "query_as" | "query_scalar" | "query_as_unchecked" | "query_scalar_unchecked" ) { - self.extract_sql_from_tokens(&node.mac.tokens); + // `query_as!(Type, sql, ...)` and `query_as_unchecked!(Type, sql, ...)` + // put the row type as the first arg. Skip past it so we inspect + // the actual SQL token, not the type ident. + let sql_tokens = if matches!(macro_name.as_str(), "query_as" | "query_as_unchecked") { + skip_first_macro_arg(&node.mac.tokens) + } else { + node.mac.tokens.clone() + }; + self.check_macro_first_arg(&sql_tokens); + self.extract_sql_from_tokens(&sql_tokens); } syn::visit::visit_expr_macro(self, node); } } +/// Drop the first comma-separated argument (and the comma itself) from a +/// macro's raw token stream. Used to strip the row type from +/// `query_as!(Type, sql, ...)` before inspecting the SQL token. +fn skip_first_macro_arg(tokens: &proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let mut depth = 0i32; + let mut seen_comma = false; + let mut out: Vec = Vec::new(); + for tt in tokens.clone() { + if seen_comma { + out.push(tt); + continue; + } + if let proc_macro2::TokenTree::Punct(ref p) = tt { + if p.as_char() == ',' && depth == 0 { + seen_comma = true; + continue; + } + if matches!(p.as_char(), '<') { + depth += 1; + } else if matches!(p.as_char(), '>') { + depth -= 1; + } + } + } + out.into_iter().collect() +} + /// Parse SQL strings and extract all selected column names. /// Returns bare column names (without table qualifiers). 
pub fn extract_columns_from_sql(sql_strings: &[String]) -> HashSet { @@ -642,11 +811,32 @@ fn expr_mentions_tenant(e: &Expr) -> bool { Expr::InList { expr, list, .. } => { expr_mentions_tenant(expr) || list.iter().any(expr_mentions_tenant) } - Expr::InSubquery { expr, .. } => expr_mentions_tenant(expr), + Expr::InSubquery { expr, subquery, .. } => { + expr_mentions_tenant(expr) || query_mentions_tenant(subquery) + } Expr::Between { expr, low, high, .. } => expr_mentions_tenant(expr) || expr_mentions_tenant(low) || expr_mentions_tenant(high), Expr::IsNull(e) | Expr::IsNotNull(e) => expr_mentions_tenant(e), + // Mirror expr_has_scope so `(claims->>'tenant_id')::uuid = $1`, + // `EXISTS (SELECT ... WHERE tenant_id = $1)`, and Snowflake-style + // `obj:tenant_id` are all recognized. + Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => query_mentions_tenant(q), + Expr::JsonAccess { value, path } => { + expr_mentions_tenant(value) + || path.path.iter().any(|elem| match elem { + sqlparser::ast::JsonPathElem::Dot { key, .. } => { + key.eq_ignore_ascii_case("tenant_id") + } + sqlparser::ast::JsonPathElem::Bracket { key } => match key { + Expr::Value(sqlparser::ast::Value::SingleQuotedString(s)) + | Expr::Value(sqlparser::ast::Value::DoubleQuotedString(s)) => { + s.eq_ignore_ascii_case("tenant_id") + } + _ => false, + }, + }) + } _ => false, } } @@ -717,8 +907,40 @@ fn stmt_is_scoped(stmt: &Statement) -> bool { let mut ctx = ScopeCtx::new(); match stmt { Statement::Query(q) => query_is_scoped(q, &mut ctx), - Statement::Update { selection, .. } => selection.as_ref().is_some_and(expr_has_scope), - Statement::Delete(d) => d.selection.as_ref().is_some_and(expr_has_scope), + Statement::Update { + selection, from, .. + } => { + // UPDATE ... FROM ... WHERE ... — the FROM clause can carry the + // scope predicate via a join expression. Walk both. 
+ if selection.as_ref().is_some_and(expr_has_scope) { + return true; + } + if let Some(from) = from { + let twj = match from { + sqlparser::ast::UpdateTableFromKind::BeforeSet(t) => t, + sqlparser::ast::UpdateTableFromKind::AfterSet(t) => t, + }; + if twj_has_scope_on_join(twj) { + return true; + } + } + false + } + Statement::Delete(d) => { + if d.selection.as_ref().is_some_and(expr_has_scope) { + return true; + } + // PG-style `DELETE FROM t USING ... WHERE ...` puts the scope + // predicate on the join in USING. Walk it. + if let Some(using) = &d.using { + for twj in using { + if twj_has_scope_on_join(twj) { + return true; + } + } + } + false + } _ => false, } } @@ -838,6 +1060,27 @@ fn source_is_scoped(factor: &TableFactor, ctx: &ScopeCtx) -> bool { } } +/// True if any JOIN ON clause attached to the given TableWithJoins carries a +/// scope predicate. Used for UPDATE/DELETE where the scope often lives on a +/// join in the FROM/USING clause rather than the top-level WHERE. +fn twj_has_scope_on_join(twj: &TableWithJoins) -> bool { + for join in &twj.joins { + let constraint = match &join.join_operator { + sqlparser::ast::JoinOperator::Inner(c) + | sqlparser::ast::JoinOperator::LeftOuter(c) + | sqlparser::ast::JoinOperator::RightOuter(c) + | sqlparser::ast::JoinOperator::FullOuter(c) => c, + _ => continue, + }; + if let sqlparser::ast::JoinConstraint::On(e) = constraint + && expr_has_scope(e) + { + return true; + } + } + false +} + fn expr_has_scope(e: &Expr) -> bool { match e { Expr::Identifier(ident) => is_scope_col(&ident.value), @@ -854,11 +1097,22 @@ fn expr_has_scope(e: &Expr) -> bool { | BinaryOperator::HashLongArrow ) { expr_has_scope(left) || value_is_scope_col(right) - } else if matches!(op, BinaryOperator::Eq | BinaryOperator::NotEq) - && (is_direct_scope_ref(left) && is_literal_value(right) - || is_direct_scope_ref(right) && is_literal_value(left)) - { - false + } else if matches!(op, BinaryOperator::Eq | BinaryOperator::NotEq) { + // Scope only 
passes when ONE side is a direct scope reference + // (or JSON-arrow into one) AND the other side is a $param + // binding. Comparing scope col to a hardcoded literal or to + // another column doesn't bind the row to the caller. + if (is_direct_scope_ref(left) && is_placeholder_value(right)) + || (is_direct_scope_ref(right) && is_placeholder_value(left)) + { + true + } else if is_direct_scope_ref(left) || is_direct_scope_ref(right) { + // Scope col compared to a literal or to another column — + // explicitly not scoped. + false + } else { + expr_has_scope(left) || expr_has_scope(right) + } } else { expr_has_scope(left) || expr_has_scope(right) } @@ -867,12 +1121,15 @@ fn expr_has_scope(e: &Expr) -> bool { Expr::Between { expr, low, high, .. } => expr_has_scope(expr) || expr_has_scope(low) || expr_has_scope(high), - Expr::IsNull(e) - | Expr::IsNotNull(e) - | Expr::IsTrue(e) - | Expr::IsNotTrue(e) - | Expr::IsFalse(e) - | Expr::IsNotFalse(e) => expr_has_scope(e), + // IS [NOT] NULL / TRUE / FALSE never compares against a parameter, + // so even if the operand names a scope column the predicate doesn't + // bind the row to the current principal. Reject these outright. + Expr::IsNull(_) + | Expr::IsNotNull(_) + | Expr::IsTrue(_) + | Expr::IsNotTrue(_) + | Expr::IsFalse(_) + | Expr::IsNotFalse(_) => false, Expr::InList { expr, list, .. } => expr_has_scope(expr) || list.iter().any(expr_has_scope), Expr::InSubquery { expr, subquery, .. } => { let sub_scoped = query_is_scoped(subquery, &mut ScopeCtx::new()); @@ -914,12 +1171,12 @@ fn is_direct_scope_ref(e: &Expr) -> bool { } } -/// True if the expression is a non-placeholder literal. Placeholders ($1, $2) are -/// acceptable scope column counterparts; hardcoded literals are not. -fn is_literal_value(e: &Expr) -> bool { +/// True if the expression eventually reduces to a parameter placeholder +/// (`$1`, `$2`, ...). Unwraps Cast/Nested wrappers so `$1::uuid` counts. 
+fn is_placeholder_value(e: &Expr) -> bool { match e { - Expr::Value(v) => !matches!(v, sqlparser::ast::Value::Placeholder(_)), - Expr::Cast { expr, .. } | Expr::Nested(expr) => is_literal_value(expr), + Expr::Value(sqlparser::ast::Value::Placeholder(_)) => true, + Expr::Cast { expr, .. } | Expr::Nested(expr) => is_placeholder_value(expr), _ => false, } } diff --git a/crates/forge-macros/src/utils.rs b/crates/forge-macros/src/utils.rs index 17e61580..ba783c3d 100644 --- a/crates/forge-macros/src/utils.rs +++ b/crates/forge-macros/src/utils.rs @@ -2,8 +2,42 @@ use std::time::Duration; +use proc_macro_crate::{FoundCrate, crate_name}; use proc_macro2::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; + +/// Resolve the path to the host `forge` crate at expansion time. +/// +/// The crate is published as the `forgex` package but its library is named +/// `forge` (`[lib] name` in `crates/forge/Cargo.toml`) so users write +/// `use forge::...`. `proc-macro-crate` returns the *dependency key* from the +/// consumer's `Cargo.toml`, which doesn't always equal the extern crate name +/// rustc sees: +/// +/// * `forge = { package = "forgex" }` (the scaffolded default) → key `forge`, +/// which is also the extern name. Emit `::forge`. +/// * a bare `forgex = "x"` dependency (what `cargo add forgex` produces, and +/// what `trybuild` generates) → key `forgex`, but rustc only knows the crate +/// by its lib name `forge`, so the key can't be used verbatim. Normalize the +/// package name back to the lib name. +/// * an explicit rename `myalias = { package = "forgex" }` → key `myalias`, +/// which *is* the extern name. Emit `::myalias`. +pub fn forge_path() -> TokenStream { + match crate_name("forgex") { + Ok(FoundCrate::Itself) => quote!(crate), + Ok(FoundCrate::Name(name)) => { + // proc-macro-crate hands back the dependency key; for a non-renamed + // `forgex` dep that key is the package name, but the crate is only + // reachable under its lib name `forge`. 
+ let extern_name = if name == "forgex" { "forge" } else { &name }; + let ident = format_ident!("{}", extern_name); + quote!(::#ident) + } + // Not resolvable as a direct dependency (transitive use, or a context + // proc-macro-crate can't read). The standard binding is `forge`. + Err(_) => quote!(::forge), + } +} /// Convert a snake_case string to PascalCase. pub fn to_pascal_case(s: &str) -> String { @@ -27,15 +61,20 @@ fn parse_duration(s: &str) -> Option { } else if let Some(num) = s.strip_suffix('s') { num.parse::().ok().map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('m') { - num.parse::().ok().map(|m| Duration::from_secs(m * 60)) + num.parse::() + .ok() + .and_then(|m| m.checked_mul(60)) + .map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('h') { num.parse::() .ok() - .map(|h| Duration::from_secs(h * 3600)) + .and_then(|h| h.checked_mul(3600)) + .map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('d') { num.parse::() .ok() - .map(|d| Duration::from_secs(d * 86400)) + .and_then(|d| d.checked_mul(86400)) + .map(Duration::from_secs) } else { // Bare integers without a unit suffix are not accepted. Require explicit // suffixes (e.g. "30s") so intent is unambiguous at the macro callsite. @@ -74,28 +113,19 @@ pub fn parse_duration_tokens(s: &str, default_secs: u64) -> TokenStream { Err(_) => invalid(), } } else if let Some(num) = s.strip_suffix('m') { - match num.parse::() { - Ok(n) => { - let secs = n * 60; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(60)) { + Some(secs) => quote! { std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else if let Some(num) = s.strip_suffix('h') { - match num.parse::() { - Ok(n) => { - let secs = n * 3600; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(3600)) { + Some(secs) => quote! 
{ std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else if let Some(num) = s.strip_suffix('d') { - match num.parse::() { - Ok(n) => { - let secs = n * 86400; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(86400)) { + Some(secs) => quote! { std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else { let _ = default_secs; diff --git a/crates/forge-macros/src/webhook.rs b/crates/forge-macros/src/webhook.rs index 72302d08..caf0472e 100644 --- a/crates/forge-macros/src/webhook.rs +++ b/crates/forge-macros/src/webhook.rs @@ -167,6 +167,7 @@ fn parse_signature_from_meta(attr_args: &[NestedMeta]) -> Result TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -279,24 +280,24 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { ) { let alg_token = match alg { WebhookSignatureAlgorithm::HmacSha256 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::HmacSha256 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::HmacSha256 } } WebhookSignatureAlgorithm::StripeWebhooks => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::StripeWebhooks } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::StripeWebhooks } } WebhookSignatureAlgorithm::HmacSha256Base64 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::HmacSha256Base64 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::HmacSha256Base64 } } WebhookSignatureAlgorithm::Ed25519 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::Ed25519 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::Ed25519 } } }; let replay_window_tokens = match attrs.replay_window_secs { Some(secs) => quote! { #secs }, - None => quote! { forge::forge_core::webhook::DEFAULT_REPLAY_WINDOW_SECS }, + None => quote! 
{ #forge::forge_core::webhook::DEFAULT_REPLAY_WINDOW_SECS }, }; quote! { - Some(forge::forge_core::webhook::SignatureConfig { + Some(#forge::forge_core::webhook::SignatureConfig { algorithm: #alg_token, header_name: #header, secret_env: #secret_env, @@ -312,15 +313,15 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { match prefix { "header" => { quote! { - Some(forge::forge_core::webhook::IdempotencyConfig::new( - forge::forge_core::webhook::IdempotencySource::Header(#value) + Some(#forge::forge_core::webhook::IdempotencyConfig::new( + #forge::forge_core::webhook::IdempotencySource::Header(#value) )) } } "body" => { quote! { - Some(forge::forge_core::webhook::IdempotencyConfig::new( - forge::forge_core::webhook::IdempotencySource::Body(#value) + Some(#forge::forge_core::webhook::IdempotencyConfig::new( + #forge::forge_core::webhook::IdempotencySource::Body(#value) )) } } @@ -337,10 +338,10 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.webhooks.register::<#struct_name>(); registries.functions.register_webhook_info( - forge::forge_core::FunctionInfo::from(&#struct_name::info()) + #forge::forge_core::FunctionInfo::from(&#struct_name::info()) ); })); } @@ -357,13 +358,13 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::webhook::ForgeWebhook for #struct_name { + impl #forge::forge_core::webhook::ForgeWebhook for #struct_name { type Payload = #payload_type; - fn info() -> forge::forge_core::webhook::WebhookInfo { - forge::forge_core::webhook::WebhookInfo { + fn info() -> #forge::forge_core::webhook::WebhookInfo { + #forge::forge_core::webhook::WebhookInfo { name: #rpc_name, description: #description_tokens, path: #path, @@ -376,9 +377,9 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::webhook::WebhookContext, + ctx: &#forge::forge_core::webhook::WebhookContext, payload: #payload_type, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/workflow.rs b/crates/forge-macros/src/workflow.rs index e3a8e04a..9abe894a 100644 --- a/crates/forge-macros/src/workflow.rs +++ b/crates/forge-macros/src/workflow.rs @@ -377,6 +377,31 @@ impl<'ast> Visit<'ast> for ContractExtractor { } } +/// Normalize a quote!-stringified type so signature hashes are stable across +/// toolchain upgrades that re-shuffle whitespace. Collapses runs of +/// whitespace to a single space and trims the ends. 
Does not re-parse — a +/// full canonicalization via syn would also normalize generics and qualified +/// paths, but the whitespace fix covers the common drift. +fn canonicalize_type_str(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + let mut last_was_space = true; // suppress leading whitespace + for ch in s.chars() { + if ch.is_whitespace() { + if !last_was_space { + out.push(' '); + last_was_space = true; + } + } else { + out.push(ch); + last_was_space = false; + } + } + if out.ends_with(' ') { + out.pop(); + } + out +} + /// Derives a 32-char hex-encoded blake3 hash (128 bits) of name, version, /// step keys, wait keys, timeout, and input/output type name strings. /// @@ -438,6 +463,7 @@ fn derive_signature( } pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -544,17 +570,30 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } }; - let version_str = attrs.version.as_deref().unwrap_or("v1"); + // Workflow versions must be explicit. A default of "v1" collides with + // a later explicit `version = "v1"` and silently keys both into the + // same WorkflowRegistry slot. + let Some(ref version_owned) = attrs.version else { + return syn::Error::new_spanned( + &input.sig.ident, + "workflow requires an explicit `version = \"...\"` attribute. \ + Pin a starting version (e.g. `version = \"v1\"`) so future revisions \ + can run alongside in-flight runs without signature collisions.", + ) + .to_compile_error() + .into(); + }; + let version_str = version_owned.as_str(); let is_public = attrs.is_public; let workflow_status = match attrs.status { WorkflowStatus::Active => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Active } + quote! 
{ #forge::forge_core::workflow::WorkflowDefStatus::Active } } WorkflowStatus::Deprecated => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Deprecated } + quote! { #forge::forge_core::workflow::WorkflowDefStatus::Deprecated } } WorkflowStatus::Staging => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Staging } + quote! { #forge::forge_core::workflow::WorkflowDefStatus::Staging } } }; @@ -583,21 +622,28 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { quote! { None } }; + // Canonicalize type-token whitespace so a syn/quote upgrade that adds or + // drops a space (e.g. `MyType` vs `MyType < Inner >`) doesn't + // silently flip the signature for every in-flight run. Pragmatic minimal + // form: trim + collapse internal runs to a single space. + let canonical_input = canonicalize_type_str(&input_type_str); + let canonical_output = canonicalize_type_str(&output_type_str); + let signature = derive_signature( workflow_name, version_str, &contract_extractor.step_keys, &contract_extractor.wait_keys, timeout_secs, - &input_type_str, - &output_type_str, + &canonical_input, + &canonical_output, ); let fn_attrs = &input.attrs; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.workflows.register::<#struct_name>(); })); } @@ -635,14 +681,14 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #[doc = #contract_doc] pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::workflow::ForgeWorkflow for #struct_name { + impl #forge::forge_core::workflow::ForgeWorkflow for #struct_name { type Input = #input_type; type Output = #output_type; - fn info() -> forge::forge_core::workflow::WorkflowInfo { - forge::forge_core::workflow::WorkflowInfo { + fn info() -> #forge::forge_core::workflow::WorkflowInfo { + #forge::forge_core::workflow::WorkflowInfo { name: #workflow_name, version: #version_str, signature: #signature, @@ -655,9 +701,9 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::workflow::WorkflowContext, + ctx: &#forge::forge_core::workflow::WorkflowContext, #input_ident: Self::Input, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-runtime/src/cluster/heartbeat.rs b/crates/forge-runtime/src/cluster/heartbeat.rs index b42dce5a..a8e01f3a 100644 --- a/crates/forge-runtime/src/cluster/heartbeat.rs +++ b/crates/forge-runtime/src/cluster/heartbeat.rs @@ -22,7 +22,10 @@ impl Default for HeartbeatConfig { interval: Duration::from_secs(5), dead_threshold: Duration::from_secs(15), mark_dead_nodes: true, - max_interval: Duration::from_secs(60), + // Capped at 30 s (was 60 s). Combined with the 3x dead_threshold + // ceiling below, worst-case detection lag drops from ~180 s to + // ~90 s for a crash during the stable phase. 
+ max_interval: Duration::from_secs(30), } } } @@ -59,7 +62,9 @@ impl HeartbeatConfig { interval: *cluster.heartbeat_interval, dead_threshold: *cluster.dead_threshold, mark_dead_nodes: true, - max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 12), + // Adaptive ceiling: 6x base (was 12x). Caps stable-phase interval + // at a tighter bound so dead-node detection stays under 90 s. + max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 6), } } } @@ -76,6 +81,12 @@ pub struct HeartbeatLoop { stable_count: AtomicU32, last_active_count: AtomicU32, /// Dedicated connection held outside the shared pool for liveness safety. + /// + /// **Persistent-connection budget**: this connection counts as the + /// 5th persistent slot alongside the 4 listed in `pg/pool.rs` (notify + /// bus listener, leader lock-owning connection, change-log listener, + /// signals writer). Pool sizing must allow `min_connections >= 5` plus + /// burst headroom or normal RPC workload contends for the remaining slots. heartbeat_conn: Mutex>, } @@ -194,12 +205,17 @@ impl HeartbeatLoop { let mut guard = self.heartbeat_conn.lock().await; if guard.ping().await.is_err() { tracing::debug!("Heartbeat connection lost; reconnecting"); + // Acquire the replacement first, then swap. If acquire fails we + // keep the (broken) old handle so we don't permanently lose the + // slot — next call retries. Explicitly drop the prior handle so + // its slot is returned to the pool before we hold the new one. 
let new_conn = self .pool .acquire() .await .map_err(forge_core::ForgeError::Database)?; - *guard = new_conn; + let old = std::mem::replace(&mut *guard, new_conn); + drop(old); } Ok(guard) } @@ -319,7 +335,7 @@ mod tests { assert_eq!(config.interval, Duration::from_secs(5)); assert_eq!(config.dead_threshold, Duration::from_secs(15)); assert!(config.mark_dead_nodes); - assert_eq!(config.max_interval, Duration::from_secs(60)); + assert_eq!(config.max_interval, Duration::from_secs(30)); } #[test] @@ -332,8 +348,8 @@ mod tests { assert_eq!(config.interval, Duration::from_secs(10)); assert_eq!(config.dead_threshold, Duration::from_secs(30)); assert!(config.mark_dead_nodes); - // max_interval = heartbeat_interval * 12 - assert_eq!(config.max_interval, Duration::from_secs(120)); + // max_interval = heartbeat_interval * 6 + assert_eq!(config.max_interval, Duration::from_secs(60)); } #[test] diff --git a/crates/forge-runtime/src/cluster/registry.rs b/crates/forge-runtime/src/cluster/registry.rs index ba1cd69f..e9bf8581 100644 --- a/crates/forge-runtime/src/cluster/registry.rs +++ b/crates/forge-runtime/src/cluster/registry.rs @@ -1,6 +1,11 @@ +use std::time::Duration; + use forge_core::cluster::{NodeInfo, NodeStatus}; use forge_core::{ForgeError, Result}; +/// How often the background compactor sweeps long-dead node rows. +const COMPACTION_INTERVAL: Duration = Duration::from_secs(6 * 60 * 60); + /// Node registry for cluster membership. pub struct NodeRegistry { pool: sqlx::PgPool, @@ -9,7 +14,48 @@ pub struct NodeRegistry { impl NodeRegistry { pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self { - Self { pool, local_node } + let registry = Self { pool, local_node }; + registry.spawn_cleanup_loop(); + registry + } + + /// Periodically delete `forge_nodes` rows that have been `dead` for + /// more than 7 days. Pod churn would otherwise accumulate rows + /// indefinitely. Runs every 6 h on a detached task; failures are logged + /// and the loop continues. 
+ fn spawn_cleanup_loop(&self) { + let pool = self.pool.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(COMPACTION_INTERVAL); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // First tick fires immediately; skip it so startup doesn't compact. + ticker.tick().await; + loop { + ticker.tick().await; + // Untyped: the parameterless DELETE produces no row data and + // adding a `.sqlx` entry for it just for the compile-time + // check has zero safety value. Allow lints locally. + #[allow(clippy::disallowed_methods)] + let res = sqlx::query( + "DELETE FROM forge_nodes \ + WHERE status = 'dead' \ + AND last_heartbeat < NOW() - INTERVAL '7 days'", + ) + .execute(&pool) + .await; + match res { + Ok(result) => { + let n = result.rows_affected(); + if n > 0 { + tracing::info!(rows = n, "Compacted dead forge_nodes rows"); + } + } + Err(e) => { + tracing::warn!(error = %e, "forge_nodes compaction failed"); + } + } + } + }); } pub async fn register(&self) -> Result<()> { diff --git a/crates/forge-runtime/src/cron/bridge.rs b/crates/forge-runtime/src/cron/bridge.rs index 8e0d9c77..61491f73 100644 --- a/crates/forge-runtime/src/cron/bridge.rs +++ b/crates/forge-runtime/src/cron/bridge.rs @@ -61,6 +61,25 @@ pub fn register_cron_bridges(cron_registry: &Arc, job_registry: &m } handler(&cron_ctx).await?; + + // Transition the claimed run from 'running' to 'completed'. The + // scheduler only ever INSERTs status='running'; this is the sole + // place a successful run is finalized, so catch-up's "last + // completed scheduled_time" lookup and operator dashboards see a + // terminal state. Scoped to status='running' so a concurrent + // stale-reclaim that already rotated the id cannot be clobbered. 
+ sqlx::query!( + r#" + UPDATE forge_cron_runs + SET status = 'completed', completed_at = NOW(), error = NULL + WHERE id = $1 AND status = 'running' + "#, + run_id, + ) + .execute(ctx.pool()) + .await + .map_err(forge_core::ForgeError::Database)?; + Ok(serde_json::Value::Null) }) }); diff --git a/crates/forge-runtime/src/cron/registry.rs b/crates/forge-runtime/src/cron/registry.rs index 9fa5c1fa..6609169b 100644 --- a/crates/forge-runtime/src/cron/registry.rs +++ b/crates/forge-runtime/src/cron/registry.rs @@ -38,7 +38,15 @@ impl CronRegistry { } pub fn register(&mut self) { - let entry = CronEntry::new::(); + self.register_entry(CronEntry::new::()); + } + + /// Insert a pre-built [`CronEntry`], keyed by its `info.name`. + /// + /// `register::` is the public path; this primitive exists so in-crate + /// tests can register a handler without a `ForgeCron` impl (the trait is + /// sealed and cannot be implemented outside `forge-core`). + pub(crate) fn register_entry(&mut self, entry: CronEntry) { self.crons.insert(entry.info.name.to_string(), entry); } diff --git a/crates/forge-runtime/src/cron/scheduler.rs b/crates/forge-runtime/src/cron/scheduler.rs index 8779c8c9..e7d082ea 100644 --- a/crates/forge-runtime/src/cron/scheduler.rs +++ b/crates/forge-runtime/src/cron/scheduler.rs @@ -223,9 +223,14 @@ impl CronRunner { async { let now = Utc::now(); + // Window is poll_interval * 4 (was *2). The wider window covers + // short GC pauses, leader-loss recovery, and DB stalls that + // otherwise drop ticks for crons without catch_up. Inverse cost + // is bounded by the UNIQUE (cron_name, scheduled_time) constraint: + // re-checking the same slot twice claims at most one job. 
let window_start = now - - chrono::Duration::from_std(self.config.poll_interval * 2) - .unwrap_or(chrono::Duration::seconds(2)); + - chrono::Duration::from_std(self.config.poll_interval * 4) + .unwrap_or(chrono::Duration::seconds(4)); let cron_list = self.registry.list(); let mut jobs_executed = 0u32; @@ -250,10 +255,11 @@ impl CronRunner { .between_in_tz(window_start, now, info.timezone); if scheduled_times.len() > 1 { - tracing::info!( + tracing::warn!( cron.name = info.name, cron.missed_count = scheduled_times.len() - 1, - "Detected missed cron runs" + catch_up_enabled = info.catch_up, + "missed cron tick: more than one scheduled time fell into the poll window" ); Span::current().record("cron.missed_runs", scheduled_times.len() - 1); } @@ -268,6 +274,19 @@ impl CronRunner { } for scheduled in scheduled_times { + // Re-check leadership between inserts so a node that lost + // the lock mid-tick stops enqueueing slots tagged with its + // node_id. UNIQUE constraint bounds the damage, but this + // cuts observability noise. + if let Some(election) = self.config.leader_election.as_ref() + && !election.is_leader() + { + tracing::debug!( + cron = info.name, + "Leadership lost mid-tick; aborting remaining enqueues" + ); + break; + } if let Ok(Some(_run_id)) = self.try_claim_and_enqueue(entry, scheduled, false).await { @@ -899,4 +918,89 @@ mod integration_tests { assert!(leader.confirm_leadership_before_tick().await); assert!(!follower.confirm_leadership_before_tick().await); } + + #[tokio::test] + async fn dispatched_cron_run_is_marked_completed_on_success() { + // The scheduler only ever writes status='running'. A successful run must + // be finalized to 'completed' by the `$cron:` bridge job that the worker + // executes. This drives the real claim path (try_claim_and_enqueue) to + // create the running row + job, then runs the real bridge handler against + // the dispatched job's input. 
No row is seeded — completion is observed + // end-to-end, not asserted on a hand-written 'completed' row. + use crate::cron::register_cron_bridges; + use forge_core::CircuitBreakerClient; + use forge_core::job::JobContext; + + let db = setup_db("cron_run_completed").await; + let pool = db.pool().clone(); + + // Real claim: inserts the 'running' row and enqueues the $cron: job. + let runner = make_runner(pool.clone()); + let entry = make_entry("nightly_report", "0 0 * * * *"); + let scheduled = Utc::now() - chrono::Duration::seconds(30); + let run_id = runner + .try_claim_and_enqueue(&entry, scheduled, false) + .await + .expect("claim ok") + .expect("claimed some id"); + + let status_after_claim: String = + sqlx::query_scalar("SELECT status FROM forge_cron_runs WHERE id = $1") + .bind(run_id) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!( + status_after_claim, "running", + "claim must leave the run in 'running'" + ); + + // Build the real bridge handler for this cron. + let cron_registry = Arc::new({ + let mut reg = CronRegistry::new(); + reg.register_entry(make_entry("nightly_report", "0 0 * * * *")); + reg + }); + let mut job_registry = crate::jobs::registry::JobRegistry::new(); + register_cron_bridges(&cron_registry, &mut job_registry); + + let bridge = job_registry + .get("$cron:nightly_report") + .expect("bridge handler registered"); + + // Feed the handler the exact input the scheduler queued for this run. 
+ let input: serde_json::Value = sqlx::query_scalar( + "SELECT input FROM forge_jobs WHERE job_type = '$cron:nightly_report'", + ) + .fetch_one(&pool) + .await + .unwrap(); + + let ctx = JobContext::new( + Uuid::new_v4(), + "$cron:nightly_report".to_string(), + 0, + 3, + pool.clone(), + CircuitBreakerClient::with_ssrf_protection(), + ); + (bridge.handler)(&ctx, input) + .await + .expect("bridge handler succeeds"); + + let (status, completed_at): (String, Option>) = + sqlx::query_as("SELECT status, completed_at FROM forge_cron_runs WHERE id = $1") + .bind(run_id) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!( + status, "completed", + "successful run must become 'completed'" + ); + assert!( + completed_at.is_some(), + "completed_at must be set on completion" + ); + } } diff --git a/crates/forge-runtime/src/daemon/runner.rs b/crates/forge-runtime/src/daemon/runner.rs index 5c410484..4ee41d50 100644 --- a/crates/forge-runtime/src/daemon/runner.rs +++ b/crates/forge-runtime/src/daemon/runner.rs @@ -110,6 +110,7 @@ impl DaemonRunner { tracing::info!(count = self.registry.len(), "Daemon runner starting"); let mut daemon_handles: HashMap = HashMap::new(); + let mut join_handles: HashMap> = HashMap::new(); for (name, entry) in self.registry.daemons() { let info = &entry.info; @@ -160,7 +161,7 @@ impl DaemonRunner { None }; - tokio::spawn(async move { + let jh = tokio::spawn(async move { run_daemon_loop( daemon_name, daemon_entry, @@ -180,6 +181,7 @@ impl DaemonRunner { .await }); + join_handles.insert(name.to_string(), jh); daemon_handles.insert(name.to_string(), handle); } @@ -191,7 +193,37 @@ impl DaemonRunner { let _ = handle.shutdown_tx.send(true); } - tokio::time::sleep(Duration::from_secs(2)).await; + // Cap the drain at 10 s but exit early when all daemons have + // signalled they observed the shutdown. 
The previous fixed 2 s + // recorded `status='stopped'` while slow daemons were still + // draining; the cap keeps shutdown bounded for the same reason + // a strict block would not (a wedged daemon must not stall the + // node forever). + const MAX_DRAIN: Duration = Duration::from_secs(10); + let drain_deadline = tokio::time::Instant::now() + MAX_DRAIN; + + // Await each daemon's task with a bounded deadline so + // `record_daemon_stop` never races ahead of the lease-refresher + // and lock-validator tasks the loop spawned per iteration. + // Abort any task that exceeds the deadline so shutdown stays + // bounded. + for (name, jh) in join_handles.drain() { + let remaining = + drain_deadline.saturating_duration_since(tokio::time::Instant::now()); + match tokio::time::timeout(remaining, jh).await { + Ok(Ok(())) => {} + Ok(Err(e)) => { + tracing::warn!(daemon = %name, error = %e, "Daemon task join failed"); + } + Err(_) => { + tracing::warn!( + daemon = %name, + "Daemon drain exceeded {:?}; aborting task to keep shutdown bounded", + MAX_DRAIN, + ); + } + } + } for (name, handle) in &daemon_handles { if let Err(e) = self.record_daemon_stop(handle).await { @@ -326,8 +358,35 @@ async fn run_daemon_loop( } Ok(false) => { tracing::debug!("Waiting for leadership"); + // If the election has a notify-bus attached, wake on + // the leader-released NOTIFY so a standby takes over + // immediately on voluntary release. Otherwise fall + // back to the 5 s poll. Filter for our own role so + // unrelated NOTIFYs don't trigger spurious wakeups. + let mut release_rx = election.subscribe_release_notify(); + // Payload is `LeaderRole::as_str()`; for Daemon variants this is + // the daemon name verbatim. 
+ let role_str = name.clone(); + let wait = async { + if let Some(rx) = release_rx.as_mut() { + loop { + match rx.recv().await { + Ok(payload) if payload == role_str => return, + Ok(_) => continue, + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => return, + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + // Bus is down; fall back to sleep. + tokio::time::sleep(Duration::from_secs(5)).await; + return; + } + } + } + } else { + tokio::time::sleep(Duration::from_secs(5)).await; + } + }; tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(5)) => {} + _ = wait => {} _ = shutdown_rx.changed() => { tracing::debug!("Shutdown while waiting for leadership"); Span::current().record("daemon.final_status", "shutdown_waiting_leadership"); @@ -592,7 +651,36 @@ async fn run_daemon_loop( if let Some(ref election) = election && let Err(e) = election.release_leadership().await { - tracing::debug!(daemon = %name, error = %e, "Failed to release leadership"); + tracing::error!( + daemon = %name, + error = %e, + "Failed to release leadership; clearing forge_leaders row so standbys can take over without waiting for the full lease" + ); + // Force-clear the leader row scoped to this node. Standbys would + // otherwise wait the full lease_duration (60 s) before + // preempting. WHERE node_id = $2 ensures we never wipe a row + // another node has already taken. + // forge_leaders is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. 
+ #[allow(clippy::disallowed_methods)] + let force_clear = sqlx::query( + r#" + DELETE FROM forge_leaders + WHERE role = $1 AND node_id = $2 + "#, + ) + .bind(name.as_str()) + .bind(node_id) + .execute(&pool) + .await; + if let Err(e2) = force_clear + { + tracing::error!( + daemon = %name, + error = %e2, + "Failed to clear forge_leaders row after release failure; standbys will wait for lease expiry" + ); + } } tracing::info!( diff --git a/crates/forge-runtime/src/gateway/admin.rs b/crates/forge-runtime/src/gateway/admin.rs index ef5c7fb4..b5dd2997 100644 --- a/crates/forge-runtime/src/gateway/admin.rs +++ b/crates/forge-runtime/src/gateway/admin.rs @@ -22,6 +22,7 @@ //! - `POST /_api/admin/queues/{name}/resume` //! - `GET /_api/admin/nodes` //! - `GET /_api/admin/leaders` +//! - `POST /_api/admin/sessions/{session_id}/revoke body: {reason?}` use std::sync::Arc; @@ -37,6 +38,9 @@ use sqlx::PgPool; use uuid::Uuid; use forge_core::function::AuthContext; +use forge_core::realtime::SessionId; + +use crate::realtime::Reactor; use super::tracing::TracingState; @@ -44,6 +48,9 @@ use super::tracing::TracingState; #[derive(Clone)] pub struct AdminState { pub db_pool: PgPool, + /// Reactor handle for session-revocation. None when running headless (e.g. + /// migration-only commands) — the route then returns 503. + pub reactor: Option>, } /// Build the admin router. Returns `None` when no admin handler can do any @@ -69,6 +76,7 @@ pub fn admin_router(state: AdminState) -> Router { .route("/admin/queues/{name}/resume", post(resume_queue)) .route("/admin/nodes", get(list_nodes)) .route("/admin/leaders", get(list_leaders)) + .route("/admin/sessions/{session_id}/revoke", post(revoke_session)) .with_state(Arc::new(state)) } @@ -1148,6 +1156,50 @@ async fn list_leaders( } } +/// Revoke a session's cached `AuthContext` so the reactor stops re-pushing +/// data tied to that session. 
Operators wire this to their identity system's +/// revocation event (role demotion, tenant move, manual sign-out across all +/// devices); the client must reconnect and re-subscribe with a fresh token to +/// resume receiving updates. +async fn revoke_session( + State(state): State>, + Extension(auth): Extension, + Extension(tracing_state): Extension, + Path(session_id): Path, + body: Option>, +) -> axum::response::Response { + if let Err(r) = require_admin(&auth) { + return r; + } + let reactor = match state.reactor.as_ref() { + Some(r) => r.clone(), + None => { + return admin_err( + StatusCode::SERVICE_UNAVAILABLE, + "reactor_unavailable", + "Realtime reactor is not running on this node", + ); + } + }; + let reason = body.and_then(|b| b.0.reason); + let reason_str = reason.as_deref().unwrap_or("admin revoke"); + reactor + .revoke_session_auth(SessionId(session_id), reason_str) + .await; + audit( + &state.db_pool, + &auth, + Some(&tracing_state), + "session.revoke", + "session", + Some(session_id.to_string()), + reason.as_deref(), + serde_json::json!({"session_id": session_id}), + ) + .await; + Json(serde_json::json!({"status": "revoked"})).into_response() +} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)] mod tests { diff --git a/crates/forge-runtime/src/gateway/auth.rs b/crates/forge-runtime/src/gateway/auth.rs index 7aa37e58..4d438686 100644 --- a/crates/forge-runtime/src/gateway/auth.rs +++ b/crates/forge-runtime/src/gateway/auth.rs @@ -17,12 +17,14 @@ use tracing::debug; use super::jwks::JwksClient; /// Derive a stable, opaque key id from an HMAC secret. We take the first -/// 8 hex chars of `SHA-256(secret_bytes)` — short enough to keep token -/// headers small, deterministic so the same secret always produces the -/// same kid, and one-way so it leaks nothing useful about the secret. 
+/// 16 hex chars (8 bytes / 64 bits) of `SHA-256(secret_bytes)` — small +/// enough to keep token headers compact while large enough to make kid +/// collisions across rotated secrets infeasible. Deterministic so the +/// same secret always produces the same kid, and one-way so it leaks +/// nothing useful about the secret. fn secret_kid(secret: &[u8]) -> String { let hash = Sha256::digest(secret); - let prefix = hash.as_slice().get(..4).unwrap_or(&[]); + let prefix = hash.as_slice().get(..8).unwrap_or(&[]); let mut out = String::with_capacity(prefix.len() * 2); for byte in prefix { use std::fmt::Write; @@ -346,6 +348,11 @@ pub struct AuthMiddleware { /// (Claims, expiry). The 256-bit key makes collisions cryptographically /// infeasible, so a hit unambiguously identifies the same token. token_cache: Arc>, + /// Monotonic instant (seconds since process start) of the last cache + /// sweep. Stored as `AtomicU64` so eviction is lock-free on the hot path. + last_cache_sweep_secs: Arc, + /// Process-start anchor for `last_cache_sweep_secs`. + cache_sweep_epoch: std::time::Instant, } impl std::fmt::Debug for AuthMiddleware { @@ -409,6 +416,8 @@ impl AuthMiddleware { hmac_kid, legacy_hmac_keys, token_cache: Arc::new(dashmap::DashMap::new()), + last_cache_sweep_secs: Arc::new(std::sync::atomic::AtomicU64::new(0)), + cache_sweep_epoch: std::time::Instant::now(), } } @@ -475,22 +484,52 @@ impl AuthMiddleware { const MAX_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60); let exp = claims.exp(); let now = chrono::Utc::now().timestamp(); - let remaining = if exp > now { - std::time::Duration::from_secs((exp - now) as u64) - } else { - std::time::Duration::ZERO - }; - remaining.min(MAX_CACHE_TTL) + // `exp` is i64 (JWT spec); guard against negative remaining so a + // platform with a 32-bit `time_t` or skewed clock can't underflow + // into an absurdly large u64 TTL. 
+ let remaining_secs = u64::try_from(exp.saturating_sub(now)).unwrap_or(0); + std::time::Duration::from_secs(remaining_secs).min(MAX_CACHE_TTL) } /// Periodically evict expired entries to prevent unbounded growth. + /// Sweeps when either (a) the cache exceeds `MAX_CACHE_SIZE` entries, or + /// (b) `SWEEP_INTERVAL` has elapsed since the last sweep. The time-based + /// trigger matters under low traffic with many short-lived tokens — the + /// size trigger alone would let stale entries accumulate indefinitely. fn evict_expired_cache_entries(&self) { + use std::sync::atomic::Ordering; const MAX_CACHE_SIZE: usize = 10_000; - if self.token_cache.len() > MAX_CACHE_SIZE { - let now = std::time::Instant::now(); - self.token_cache - .retain(|_, (_, expires_at)| *expires_at > now); + const SWEEP_INTERVAL_SECS: u64 = 60; + + let now_instant = std::time::Instant::now(); + let elapsed_since_start = now_instant + .saturating_duration_since(self.cache_sweep_epoch) + .as_secs(); + let last = self.last_cache_sweep_secs.load(Ordering::Relaxed); + let time_due = elapsed_since_start.saturating_sub(last) >= SWEEP_INTERVAL_SECS; + let size_due = self.token_cache.len() > MAX_CACHE_SIZE; + + if !(size_due || time_due) { + return; } + + // Race-free claim: only the caller that successfully advances the + // sweep timestamp performs the scan; concurrent callers skip. + if self + .last_cache_sweep_secs + .compare_exchange( + last, + elapsed_since_start, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + self.token_cache + .retain(|_, (_, expires_at)| *expires_at > now_instant); } /// Validate HMAC-signed token. 
Uses the token's `kid` header to look up @@ -541,21 +580,42 @@ impl AuthMiddleware { let safe_kid = header.kid.as_deref().map(sanitize_for_log); debug!(kid = ?safe_kid, alg = ?header.alg, "Validating RSA token"); - let key = if let Some(kid) = header.kid { - jwks.get_key(&kid).await.map_err(|e| { + if let Some(kid) = header.kid { + let key = jwks.get_key(&kid).await.map_err(|e| { AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e)) - })? - } else if self.config.jwks_require_kid { + })?; + return self.decode_and_validate(token, &key); + } + + if self.config.jwks_require_kid { return Err(AuthError::InvalidToken( "RS256 token missing kid header; set auth.jwks_require_kid = false to allow kidless tokens".to_string(), )); - } else { - jwks.get_any_key() - .await - .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))? - }; + } - self.decode_and_validate(token, &key) + // No kid: try every kidless key. A signature mismatch under one key + // does not imply the token is invalid — providers that publish multiple + // kidless keys (during rotation) require us to attempt each. + let candidates = jwks + .kidless_keys() + .await + .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?; + if candidates.is_empty() { + return Err(AuthError::InvalidToken( + "No kidless JWKS keys available for kidless token".to_string(), + )); + } + + let mut last_err: Option = None; + for key in &candidates { + match self.decode_and_validate(token, key) { + Ok(claims) => return Ok(claims), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap_or_else(|| { + AuthError::InvalidToken("Kidless token did not validate against any key".to_string()) + })) } /// Decode and validate token with the given key. @@ -627,6 +687,10 @@ impl AuthMiddleware { } /// Decode JWT token without signature verification (DEV MODE ONLY). 
+ /// Still enforces `exp` so that a missing or zero expiry — which prod + /// validation rejects via `required_claims` — also fails here. Matching + /// the prod rule shrinks the downgrade-attack surface when a dev-built + /// token accidentally reaches a prod-config gateway. fn decode_without_verification(&self, token: &str) -> Result { let token_data = dangerous::insecure_decode::(token).map_err(|e| match e.kind() { @@ -636,6 +700,14 @@ impl AuthMiddleware { _ => AuthError::InvalidToken(e.to_string()), })?; + // `exp == 0` (epoch) would deserialize fine since `i64` has no default + // guard against it. Treat as missing. + if token_data.claims.exp() <= 0 { + return Err(AuthError::InvalidToken( + "Token missing required `exp` claim".to_string(), + )); + } + if token_data.claims.is_expired() { return Err(AuthError::TokenExpired); } @@ -810,12 +882,22 @@ pub async fn auth_middleware( let should_set_cookie = auth_context.is_authenticated() && middleware.config.jwt_secret.is_some(); - // Skip cookie if one already exists (avoids resigning on every request) + // Skip cookie if one already exists (avoids resigning on every request). + // Parse the Cookie header strictly: split on ';', trim whitespace, and + // match name exactly so substrings like `xforge_session` don't trigger. let has_session_cookie = req .headers() .get(header::COOKIE) .and_then(|v| v.to_str().ok()) - .map(|c| c.contains("forge_session=")) + .map(|c| { + c.split(';').any(|pair| { + let trimmed = pair.trim_start(); + trimmed + .split_once('=') + .map(|(name, _)| name == "forge_session") + .unwrap_or(false) + }) + }) .unwrap_or(false); let should_set_cookie = should_set_cookie && !has_session_cookie; @@ -851,8 +933,12 @@ pub async fn auth_middleware( // browsers refuse to send `Secure` cookies over HTTP, which surfaces // misconfigured deployments as a clean failure rather than silently // weakening the session. 
+ // Path=/ ensures the browser sends the cookie on every request, so the + // resign-skip check above actually fires. With a narrower path the + // cookie would only be visible to /_api/oauth/* and every other request + // would re-sign it. let cookie = format!( - "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Secure; Max-Age={cookie_ttl}" + "forge_session={cookie_value}; Path=/; HttpOnly; SameSite=Lax; Secure; Max-Age={cookie_ttl}" ); if let Ok(val) = axum::http::HeaderValue::from_str(&cookie) { response.headers_mut().append(header::SET_COOKIE, val); @@ -1555,7 +1641,7 @@ mod tests { let kid_a = secret_kid(b"some-secret"); let kid_b = secret_kid(b"some-secret"); assert_eq!(kid_a, kid_b); - assert_eq!(kid_a.len(), 8, "kid should be 8 hex chars (4 bytes)"); + assert_eq!(kid_a.len(), 16, "kid should be 16 hex chars (8 bytes)"); assert_ne!(kid_a, secret_kid(b"different-secret")); } @@ -1667,7 +1753,7 @@ mod tests { let config = AuthConfig { algorithm: JwtAlgorithm::RS256, jwks_client: Some(Arc::new( - JwksClient::new("http://example.invalid".into(), 3600).unwrap(), + JwksClient::new("https://example.invalid".into(), 3600).unwrap(), )), jwks_require_kid: true, ..AuthConfig::default() diff --git a/crates/forge-runtime/src/gateway/jwks.rs b/crates/forge-runtime/src/gateway/jwks.rs index 2e0934dc..817c3932 100644 --- a/crates/forge-runtime/src/gateway/jwks.rs +++ b/crates/forge-runtime/src/gateway/jwks.rs @@ -56,6 +56,10 @@ pub struct JsonWebKey { struct CachedJwks { /// Map of key ID to decoding key. keys: HashMap, + /// Keys served by the provider without a `kid`. Stored separately so a + /// rotation that ships multiple kidless keys does not silently lose every + /// one but the last. + kidless_keys: Vec, /// When the cache was last refreshed. 
fetched_at: Instant, } @@ -113,6 +117,15 @@ impl JwksClient { /// * `url` - The JWKS endpoint URL /// * `cache_ttl_secs` - How long to cache keys (in seconds) pub fn new(url: String, cache_ttl_secs: u64) -> Result { + // Reject plain-HTTP JWKS endpoints: an on-path attacker can swap keys + // and mint arbitrary RS256 tokens. Loopback is permitted for local + // identity-provider stubs in tests and development. + let insecure = forge_core::util::http_hostname(&url) + .is_some_and(|host| !forge_core::util::is_loopback_host(host)); + if insecure { + return Err(JwksError::InsecureUrl(url)); + } + let http_client = reqwest::Client::builder() .timeout(Duration::from_secs(10)) .build() @@ -206,15 +219,22 @@ impl JwksClient { /// Some providers don't include a key ID in tokens. This method /// returns the first available key from the JWKS. pub async fn get_any_key(&self) -> Result { - // Try to get from cache first + // Try to get from cache first. Kidless keys are preferred for + // kidless-token fallback so a provider rotation that ships multiple + // kidless keys still has every entry reachable here. { let cache = self.cache.read().await; if let Some(ref cached) = *cache && cached.fetched_at.elapsed() < self.cache_ttl - && let Some(key) = cached.keys.values().next() { - debug!("Using first cached JWKS key (no kid specified)"); - return Ok(key.clone()); + if let Some(key) = cached.kidless_keys.first() { + debug!("Using first cached kidless JWKS key"); + return Ok(key.clone()); + } + if let Some(key) = cached.keys.values().next() { + debug!("Using first cached JWKS key (no kid specified)"); + return Ok(key.clone()); + } } } @@ -224,6 +244,9 @@ impl JwksClient { let cache = self.cache.read().await; if let Some(ref cached) = *cache { + if let Some(key) = cached.kidless_keys.first().cloned() { + return Ok(key); + } cached .keys .values() @@ -235,6 +258,29 @@ impl JwksClient { } } + /// Try every cached kidless key in turn. 
Used by RSA validation when the + /// incoming token carries no `kid` header — without this, a kidless-key + /// rotation silently fails for tokens signed by the second key. + pub async fn kidless_keys(&self) -> Result, JwksError> { + { + let cache = self.cache.read().await; + if let Some(ref cached) = *cache + && cached.fetched_at.elapsed() < self.cache_ttl + && !cached.kidless_keys.is_empty() + { + return Ok(cached.kidless_keys.clone()); + } + } + self.refresh_if_needed().await?; + let cache = self.cache.read().await; + match *cache { + Some(ref cached) => Ok(cached.kidless_keys.clone()), + None => Err(JwksError::FetchFailed( + "Cache empty after refresh".to_string(), + )), + } + } + /// Force refresh the key cache. /// /// Fetches fresh keys from the JWKS endpoint regardless of cache state. @@ -261,6 +307,7 @@ impl JwksClient { .map_err(|e| JwksError::ParseFailed(e.to_string()))?; let mut keys = HashMap::new(); + let mut kidless_keys = Vec::new(); for jwk in jwks.keys { // Skip non-signature keys @@ -270,27 +317,36 @@ impl JwksClient { continue; } - let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string()); + let kid_for_log = jwk.kid.as_deref().unwrap_or("").to_string(); match self.parse_jwk(&jwk) { Ok(Some(key)) => { - debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key"); - keys.insert(kid, key); + debug!(kid = %kid_for_log, kty = %jwk.kty, "Parsed JWKS key"); + match jwk.kid { + Some(k) => { + keys.insert(k, key); + } + None => kidless_keys.push(key), + } } Ok(None) => { - debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type"); + debug!(kid = %kid_for_log, kty = %jwk.kty, "Skipping unsupported key type"); } Err(e) => { - warn!(kid = %kid, error = %e, "Failed to parse JWKS key"); + warn!(kid = %kid_for_log, error = %e, "Failed to parse JWKS key"); } } } - if keys.is_empty() { + if keys.is_empty() && kidless_keys.is_empty() { return Err(JwksError::NoKeysAvailable); } - debug!(count = keys.len(), "Cached JWKS keys"); + debug!( + count = 
keys.len(), + kidless = kidless_keys.len(), + "Cached JWKS keys" + ); // Drop negative-cache entries for any kid that's now present, so a // rotation that hands us a previously-missing kid takes effect @@ -302,6 +358,7 @@ impl JwksClient { let mut cache = self.cache.write().await; *cache = Some(CachedJwks { keys, + kidless_keys, fetched_at: Instant::now(), }); @@ -374,6 +431,10 @@ pub enum JwksError { /// Failed to create HTTP client. #[error("Failed to create HTTP client: {0}")] HttpClientError(String), + + /// JWKS URL uses an insecure scheme (plain http) outside loopback. + #[error("JWKS URL '{0}' must use https:// (plain http is only allowed for loopback hosts)")] + InsecureUrl(String), } #[cfg(test)] @@ -383,7 +444,7 @@ mod tests { #[test] fn test_parse_jwk_with_n_e() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); // Example RSA public key components (minimal test) let jwk = JsonWebKey { @@ -404,7 +465,7 @@ mod tests { #[test] fn test_parse_jwk_unsupported_type() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); let jwk = JsonWebKey { kid: Some("test-key".to_string()), @@ -423,7 +484,7 @@ mod tests { #[test] fn test_parse_jwk_missing_components() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); let jwk = JsonWebKey { kid: Some("test-key".to_string()), @@ -452,7 +513,7 @@ mod tests { // "oct" (symmetric) keys can't be used for asymmetric verification; we // skip them silently rather than erroring, so the caller can keep // processing the rest of the JWKS. 
- let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("sym".into()), kty: "oct".into(), @@ -468,7 +529,7 @@ mod tests { #[test] fn parse_jwk_returns_none_when_only_modulus_present() { // RSA with `n` but no `e` is malformed; we drop it rather than crashing. - let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("partial".into()), kty: "RSA".into(), @@ -486,7 +547,7 @@ mod tests { // When x5c is present, the implementation uses it first. A garbage // cert string therefore surfaces as KeyParseFailed, not silent // fallthrough to the n/e branch (which would otherwise succeed). - let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("bad-x5c".into()), kty: "RSA".into(), @@ -536,12 +597,13 @@ mod tests { async fn get_key_returns_cached_match_without_network() { // Pre-populate the cache so the read path is exercised without // touching the JWKS endpoint. Verifies the cached-key fast path. - let client = JwksClient::new("http://example.invalid".into(), 3600).unwrap(); + let client = JwksClient::new("https://example.invalid".into(), 3600).unwrap(); let key = DecodingKey::from_secret(b"placeholder"); let mut keys = HashMap::new(); keys.insert("kid-1".to_string(), key); *client.cache.write().await = Some(CachedJwks { keys, + kidless_keys: Vec::new(), fetched_at: Instant::now(), }); @@ -554,11 +616,12 @@ mod tests { async fn get_any_key_returns_first_cached_when_kid_absent() { // Some providers issue tokens without a `kid` header; the fallback // must return whichever key is cached. 
- let client = JwksClient::new("http://example.invalid".into(), 3600).unwrap(); + let client = JwksClient::new("https://example.invalid".into(), 3600).unwrap(); let mut keys = HashMap::new(); keys.insert("only".into(), DecodingKey::from_secret(b"placeholder")); *client.cache.write().await = Some(CachedJwks { keys, + kidless_keys: Vec::new(), fetched_at: Instant::now(), }); diff --git a/crates/forge-runtime/src/gateway/mcp/mod.rs b/crates/forge-runtime/src/gateway/mcp/mod.rs index 007c889f..463eac7a 100644 --- a/crates/forge-runtime/src/gateway/mcp/mod.rs +++ b/crates/forge-runtime/src/gateway/mcp/mod.rs @@ -114,7 +114,11 @@ impl Stream for McpReceiverStream { /// Clients use this to receive notifications and asynchronous responses /// from the MCP server. The stream starts with an `endpoint` event /// containing the session ID, then sends keepalive pings every 30 seconds. -pub async fn mcp_get_handler(State(state): State>, headers: HeaderMap) -> Response { +pub async fn mcp_get_handler( + State(state): State>, + Extension(auth): Extension, + headers: HeaderMap, +) -> Response { if let Err(resp) = validate_origin(&headers, &state.config) { return *resp; } @@ -129,6 +133,28 @@ pub async fn mcp_get_handler(State(state): State>, headers: Header Err(resp) => return resp, }; + // Bind the SSE stream to the principal that initialized the session. + // Otherwise any caller with a leaked session id can attach to that + // session's stream and impersonate it. 
+ { + let sessions = state.sessions.read().await; + if let Some(session) = sessions.get(&session_id) { + let current = auth.principal_id(); + if session.principal_id != current { + return ( + StatusCode::FORBIDDEN, + Json(json_rpc_error( + None, + -32001, + "Session principal mismatch", + None, + )), + ) + .into_response(); + } + } + } + state.touch_session(&session_id).await; // Create a channel for server-to-client messages diff --git a/crates/forge-runtime/src/gateway/mcp/session.rs b/crates/forge-runtime/src/gateway/mcp/session.rs index 3434f987..700df5b0 100644 --- a/crates/forge-runtime/src/gateway/mcp/session.rs +++ b/crates/forge-runtime/src/gateway/mcp/session.rs @@ -89,13 +89,46 @@ pub(super) fn validate_origin( headers: &HeaderMap, config: &McpConfig, ) -> std::result::Result<(), ResponseError> { - let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else { - return Ok(()); - }; + let origin = headers.get("origin").and_then(|v| v.to_str().ok()); - // When no allowed_origins are configured, reject cross-origin requests - // rather than allowing all origins (secure by default) - if config.allowed_origins.is_empty() { + // When the operator has configured an allow-list, the Origin header is + // mandatory. Without this, a browser-adjacent context exploiting DNS + // rebinding (or any client suppressing Origin) bypasses the allow-list. 
+ if !config.allowed_origins.is_empty() { + let allow_any = config.allowed_origins.iter().any(|c| c == "*"); + return match origin { + Some(o) => { + let allowed = allow_any + || config + .allowed_origins + .iter() + .any(|candidate| candidate.eq_ignore_ascii_case(o)); + if allowed { + Ok(()) + } else { + Err(Box::new( + ( + StatusCode::FORBIDDEN, + Json(json_rpc_error(None, -32600, "Invalid Origin header", None)), + ) + .into_response(), + )) + } + } + None if allow_any => Ok(()), + None => Err(Box::new( + ( + StatusCode::FORBIDDEN, + Json(json_rpc_error(None, -32600, "Missing Origin header", None)), + ) + .into_response(), + )), + }; + } + + // No allow-list configured: keep the "secure by default" reject for + // cross-origin requests, and allow non-browser clients that omit Origin. + if origin.is_some() { return Err(Box::new( ( StatusCode::FORBIDDEN, @@ -110,21 +143,7 @@ pub(super) fn validate_origin( )); } - let allowed = config - .allowed_origins - .iter() - .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin)); - if allowed { - return Ok(()); - } - - Err(Box::new( - ( - StatusCode::FORBIDDEN, - Json(json_rpc_error(None, -32600, "Invalid Origin header", None)), - ) - .into_response(), - )) + Ok(()) } pub(super) fn enforce_protocol_header( diff --git a/crates/forge-runtime/src/gateway/mcp/tools.rs b/crates/forge-runtime/src/gateway/mcp/tools.rs index f1762ca9..ce10e231 100644 --- a/crates/forge-runtime/src/gateway/mcp/tools.rs +++ b/crates/forge-runtime/src/gateway/mcp/tools.rs @@ -484,23 +484,68 @@ pub(super) async fn handle_proxied_function_call( } } +/// Hard cap on serialized MCP tool output. Past this size we drop +/// `structuredContent` (avoiding the double-encoded blow-up for objects) and +/// truncate the textual representation. Picked to leave generous headroom +/// while preventing a single tool call from buffering tens of MB twice. 
+const MAX_TOOL_OUTPUT_BYTES: usize = 256 * 1024; + pub(super) fn tool_success_result(output: Value) -> Value { match output { - Value::Object(_) => serde_json::json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string()) - }], - "structuredContent": output - }), - Value::String(text) => serde_json::json!({ - "content": [{ "type": "text", "text": text }] - }), - other => serde_json::json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string()) - }] - }), + Value::Object(_) => { + let text = serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string()); + if text.len() > MAX_TOOL_OUTPUT_BYTES { + // Avoid embedding the object twice once it exceeds the cap. + let truncated = truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES); + serde_json::json!({ + "content": [{ + "type": "text", + "text": truncated + }], + "isError": false, + "_truncated": true + }) + } else { + serde_json::json!({ + "content": [{ "type": "text", "text": text }], + "structuredContent": output + }) + } + } + Value::String(text) => { + let text = if text.len() > MAX_TOOL_OUTPUT_BYTES { + truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES) + } else { + text + }; + serde_json::json!({ + "content": [{ "type": "text", "text": text }] + }) + } + other => { + let text = serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string()); + let text = if text.len() > MAX_TOOL_OUTPUT_BYTES { + truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES) + } else { + text + }; + serde_json::json!({ + "content": [{ "type": "text", "text": text }] + }) + } + } +} + +fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; } + let mut out = String::with_capacity(end + 16); + out.push_str(s.get(..end).unwrap_or("")); + 
out.push_str("…[truncated]"); + out } diff --git a/crates/forge-runtime/src/gateway/mod.rs b/crates/forge-runtime/src/gateway/mod.rs index c9e45fec..b653c1ab 100644 --- a/crates/forge-runtime/src/gateway/mod.rs +++ b/crates/forge-runtime/src/gateway/mod.rs @@ -25,8 +25,9 @@ pub use response::{RpcError, RpcResponse}; pub use rpc::RpcHandler; pub use server::{GatewayConfig, GatewayServer, TrustedProxies}; pub use sse::{ - SseConfig, SsePayload, SseQuery, SseState, sse_handler, sse_job_subscribe_handler, - sse_subscribe_handler, sse_unsubscribe_handler, sse_workflow_subscribe_handler, + SseConfig, SsePayload, SseQuery, SseState, SseTicketResponse, sse_handler, + sse_job_subscribe_handler, sse_subscribe_handler, sse_ticket_handler, sse_unsubscribe_handler, + sse_workflow_subscribe_handler, }; pub use tls::{ GatewayConn, GatewayListener, PeerAddr, TlsListenConfig, bind_listener, load_rustls_config, diff --git a/crates/forge-runtime/src/gateway/multipart.rs b/crates/forge-runtime/src/gateway/multipart.rs index 7933be2c..20fa2521 100644 --- a/crates/forge-runtime/src/gateway/multipart.rs +++ b/crates/forge-runtime/src/gateway/multipart.rs @@ -12,9 +12,106 @@ use forge_core::types::Upload; use super::rpc::RpcHandler; const MAX_FIELD_NAME_LENGTH: usize = 255; +const MAX_FILENAME_LENGTH: usize = 255; const MAX_JSON_FIELD_SIZE: usize = 1024 * 1024; const JSON_FIELD_NAME: &str = "_json"; +/// Sanitize a raw upload filename for safe persistence and logging. +/// +/// Strips path components, neutralizes traversal sequences after the basename +/// is isolated, rejects null bytes / control chars, and rewrites Windows +/// reserved device names (CON, PRN, NUL, AUX, COM[1-9], LPT[1-9]) so that +/// downstream code that mirrors the upload to disk on Windows can't trip the +/// reserved-name handling. Returns `None` when nothing salvageable remains. 
+fn sanitize_filename(raw: &str) -> Option<String> { + // Basename first: take the last path component for either separator. A + // bare `.` or `..` is rejected below rather than rewritten, while dots + // inside a longer name (`foo..bar`) are left untouched. + let basename = raw.rsplit(['/', '\\']).next().unwrap_or(raw); + let basename = basename.trim(); + if basename.is_empty() || basename == "." || basename == ".." { + return None; + } + + // Reject controls and null bytes outright. A null byte in a filename is a + // classic log-truncation / path-confusion vector. + if basename.chars().any(|c| c == '\0' || c.is_control()) { + return None; + } + + let mut name: String = basename + .chars() + .map(|c| match c { + '<' | '>' | ':' | '"' | '|' | '?' | '*' => '_', + _ => c, + }) + .collect(); + + // Windows reserved device names match against the basename without its + // extension. Comparison is case-insensitive. + let stem = name.split('.').next().unwrap_or(&name).to_ascii_uppercase(); + let is_reserved_dev = |prefix: &str| -> bool { + stem.strip_prefix(prefix) + .and_then(|rest| rest.parse::<u8>().ok()) + .is_some_and(|n| (1..=9).contains(&n)) + }; + let is_reserved = matches!(stem.as_str(), "CON" | "PRN" | "AUX" | "NUL") + || is_reserved_dev("COM") + || is_reserved_dev("LPT"); + if is_reserved { + name = format!("_{name}"); + } + + // `String::truncate` panics off a char boundary; back up to the nearest one. + let end = (0..=name.len().min(MAX_FILENAME_LENGTH)).rev().find(|&i| name.is_char_boundary(i)); + name.truncate(end.unwrap_or(0)); + + if name.is_empty() { None } else { Some(name) } +} + +#[cfg(test)] +mod sanitize_tests { + use super::sanitize_filename; + + #[test] + fn strips_path_components() { + assert_eq!(sanitize_filename("/etc/passwd").as_deref(), Some("passwd")); + assert_eq!( + sanitize_filename("C:\\Windows\\system.ini").as_deref(), + Some("system.ini") + ); + } + + #[test] + fn preserves_legitimate_double_dots() { + assert_eq!(sanitize_filename("foo..bar").as_deref(), Some("foo..bar")); + } + + #[test] + fn rejects_traversal_only_basename() { + 
assert!(sanitize_filename("../../etc/passwd").is_some()); + assert_eq!( + sanitize_filename("../../etc/passwd").as_deref(), + Some("passwd") + ); + assert!(sanitize_filename("..").is_none()); + assert!(sanitize_filename("foo/..").is_none()); + } + + #[test] + fn rejects_control_chars_and_nulls() { + assert!(sanitize_filename("foo\0bar").is_none()); + assert!(sanitize_filename("foo\nbar").is_none()); + } + + #[test] + fn rewrites_windows_reserved_names() { + assert_eq!(sanitize_filename("CON").as_deref(), Some("_CON")); + assert_eq!(sanitize_filename("nul.txt").as_deref(), Some("_nul.txt")); + assert_eq!(sanitize_filename("COM1.log").as_deref(), Some("_COM1.log")); + } +} + /// Lightweight magic-byte check: for a small set of well-known media types, /// the declared `Content-Type` must match the file's leading bytes. Types /// outside this list pass through (we don't have a signature library, and @@ -169,6 +266,17 @@ pub async fn rpc_multipart_handler( } if name == JSON_FIELD_NAME { + // Duplicate `_json` would silently let the last value win, which + // is a parameter-smuggling avenue when upstream validators only + // inspected the first occurrence. 
+ if json_args.is_some() { + return multipart_error( + StatusCode::BAD_REQUEST, + "DUPLICATE_FIELD", + "Multiple `_json` fields submitted; only one is allowed", + ); + } + let mut buffer = BytesMut::new(); let mut json_field = field; @@ -180,8 +288,7 @@ pub async fn rpc_multipart_handler( StatusCode::PAYLOAD_TOO_LARGE, "PAYLOAD_TOO_LARGE", format!( - "Multipart payload exceeds maximum size of {} bytes", - max_total + "Multipart payload exceeds maximum size of {max_total} bytes (field `_json`)" ), ); } @@ -235,20 +342,16 @@ pub async fn rpc_multipart_handler( .file_name() .map(String::from) .unwrap_or_else(|| name.clone()); - // Sanitize filename: strip path components to prevent path traversal - let filename = raw_filename - .rsplit(['/', '\\']) - .next() - .unwrap_or(&raw_filename) - .replace("..", "_") - .to_string(); - if filename.is_empty() { - return multipart_error( - StatusCode::BAD_REQUEST, - "INVALID_FILENAME", - "Filename is empty after sanitization", - ); - } + let filename = match sanitize_filename(&raw_filename) { + Some(f) => f, + None => { + return multipart_error( + StatusCode::BAD_REQUEST, + "INVALID_FILENAME", + "Filename empty or contains disallowed characters after sanitization", + ); + } + }; let content_type = field .content_type() .map(String::from) @@ -265,8 +368,7 @@ pub async fn rpc_multipart_handler( StatusCode::PAYLOAD_TOO_LARGE, "PAYLOAD_TOO_LARGE", format!( - "Multipart payload exceeds maximum size of {} bytes", - max_total + "Multipart payload exceeds maximum size of {max_total} bytes (field `{name}`, file `{filename}`)" ), ); } diff --git a/crates/forge-runtime/src/gateway/oauth.rs b/crates/forge-runtime/src/gateway/oauth.rs index a586d54b..1beef750 100644 --- a/crates/forge-runtime/src/gateway/oauth.rs +++ b/crates/forge-runtime/src/gateway/oauth.rs @@ -398,10 +398,12 @@ pub async fn oauth_register( ) .into_response(); } - // Check scheme and host - let is_localhost = uri.starts_with("http://localhost") - || 
uri.starts_with("http://127.0.0.1") - || uri.starts_with("http://[::1]"); + // Check scheme and host. `starts_with("http://localhost")` would also + // pass `http://localhost.evil.com/cb`, letting an attacker register a + // non-loopback redirect over plain HTTP. Extract the hostname and + // compare exactly. + let is_localhost = + forge_core::util::http_hostname(uri).is_some_and(forge_core::util::is_loopback_host); let is_https = uri.starts_with("https://"); if !is_localhost && !is_https { return ( @@ -661,10 +663,20 @@ pub async fn oauth_authorize_post( .into_response(); } - // Rate limit login failures (T7). PG-backed key so the budget is shared - // across cluster nodes. + // Rate limit every authorize POST branch (T7). Applied before branch + // dispatch so token and session-cookie flows can't bypass the budget — + // each branch exercises crypto-heavy paths (JWT validate, cookie HMAC, + // argon2id) that are otherwise free abuse amplifiers. let ip = resolved_ip.0.as_deref().unwrap_or("unknown"); let rate_key = format!("oauth:login:{ip}"); + if !state.rate_check(&rate_key, LOGIN_FAIL_RATE_LIMIT).await { + return authorize_error_redirect( + &form.redirect_uri, + form.state.as_deref(), + "access_denied", + "Too many authorization attempts. Please try again later.", + ); + } // Validate client and redirect_uri again (form could be tampered) let client = sqlx::query!( @@ -714,16 +726,22 @@ pub async fn oauth_authorize_post( }); if let Some(subject) = session_subject { - // Session cookie flow: user identified by signed cookie from previous API calls. - // Subject may be a UUID (HMAC auth) or an external provider ID (Firebase, Clerk). - user_id = subject.parse::().unwrap_or_else(|_| { - // Non-UUID subject (Firebase UID, etc.): deterministic UUID from subject hash. 
- use sha2::Digest; - let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into(); - let mut bytes = [0u8; 16]; - bytes.copy_from_slice(&hash[..16]); - Uuid::from_bytes(bytes) - }); + // Session cookie flow: subject must already be a real users.id UUID. + // Hashing an external provider subject into a fake UUID would forge a + // FK to a row that does not exist, and two different external subjects + // could collide on the same 128-bit prefix. External-provider users + // must complete the bearer-token branch instead. + match subject.parse::() { + Ok(uid) => user_id = uid, + Err(_) => { + return authorize_error_redirect( + &form.redirect_uri, + form.state.as_deref(), + "access_denied", + "Session subject is not a Forge user id. Sign in with a bearer token.", + ); + } + } } else if let Some(token) = &form.token { // Consent flow: validate existing JWT match state.auth_middleware.validate_token_async(token).await { @@ -762,15 +780,6 @@ pub async fn oauth_authorize_post( ); } - if !state.rate_check(&rate_key, LOGIN_FAIL_RATE_LIMIT).await { - return authorize_error_redirect( - &form.redirect_uri, - form.state.as_deref(), - "access_denied", - "Too many login attempts. Please try again later.", - ); - } - // Query users table by convention let row = sqlx::query!( "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1", @@ -1156,14 +1165,16 @@ fn base_url_from_headers(headers: &HeaderMap) -> String { .and_then(|v| v.to_str().ok()) .unwrap_or("localhost:9081"); + // Parse host name (drop port and IPv6 brackets) and compare exactly. A + // `starts_with` test would match `localhost.evil.com`, causing the metadata + // to advertise an attacker-controlled http:// URL. + let hostname = forge_core::util::hostname_from_authority(host); + let is_loopback = forge_core::util::is_loopback_host(hostname); + // Do not trust x-forwarded-proto: OAuth routes bypass the trusted-proxy // middleware, so any client can spoof the header. 
Default to "https" for - // production safety; localhost gets "http" for local development. - let scheme = if host.starts_with("localhost") || host.starts_with("127.0.0.1") { - "http" - } else { - "https" - }; + // production safety; loopback gets "http" for local development. + let scheme = if is_loopback { "http" } else { "https" }; format!("{scheme}://{host}") } diff --git a/crates/forge-runtime/src/gateway/rpc.rs b/crates/forge-runtime/src/gateway/rpc.rs index 09767f1f..e6caeb3a 100644 --- a/crates/forge-runtime/src/gateway/rpc.rs +++ b/crates/forge-runtime/src/gateway/rpc.rs @@ -214,12 +214,18 @@ pub struct RpcFunctionBody { /// Validate that a function name contains only safe characters. /// Prevents log injection and unexpected behavior from special characters. +/// Leading `.` (including `..`) is rejected: dotted segments are reserved +/// for module paths and a leading dot looks like a path-traversal attempt +/// — neither maps to a real handler, so failing loud beats a 404 later. fn is_valid_function_name(name: &str) -> bool { - !name.is_empty() - && name.len() <= 256 - && name - .chars() - .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-') + if name.is_empty() || name.len() > 256 { + return false; + } + if name.starts_with('.') { + return false; + } + name.chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-') } /// Axum handler for POST /rpc/:function (REST-style). @@ -314,6 +320,15 @@ mod tests { assert!(!is_valid_function_name("question?")); } + #[test] + fn function_name_rejects_leading_dot() { + // Leading dot (or `..`) reads as a path-traversal attempt and never + // maps to a real handler. 
+ assert!(!is_valid_function_name(".hidden")); + assert!(!is_valid_function_name("..parent")); + assert!(!is_valid_function_name(".")); + } + #[test] fn user_agent_returns_value_when_header_present() { let mut headers = HeaderMap::new(); diff --git a/crates/forge-runtime/src/gateway/server.rs b/crates/forge-runtime/src/gateway/server.rs index 9d50c831..634466e2 100644 --- a/crates/forge-runtime/src/gateway/server.rs +++ b/crates/forge-runtime/src/gateway/server.rs @@ -34,7 +34,7 @@ use super::multipart::{MultipartConfig, rpc_multipart_handler}; use super::response::{RpcError, RpcResponse}; use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler}; use super::sse::{ - SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler, + SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler, sse_ticket_handler, sse_unsubscribe_handler, sse_workflow_subscribe_handler, }; use super::tls::{TlsListenConfig, bind_listener}; @@ -48,19 +48,12 @@ const DEFAULT_MAX_JSON_BODY_SIZE: usize = 1024 * 1024; const DEFAULT_MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024; const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; const MAX_MULTIPART_CONCURRENCY: usize = 32; -/// Fallback for visitor ID hashing when no JWT secret is configured (dev only). -const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret"; - -/// Resolve the visitor-ID hashing secret, falling back to a stable dev value -/// with a one-time warning when no JWT secret is configured. -fn signal_visitor_secret(jwt_secret: &Option) -> String { - jwt_secret.clone().unwrap_or_else(|| { - tracing::warn!( - "No jwt_secret configured; using default signal secret for visitor ID hashing. \ - Visitor IDs will be predictable. Set [auth] jwt_secret in forge.toml." - ); - DEFAULT_SIGNAL_SECRET.to_string() - }) +/// Resolve the visitor-ID hashing secret. 
Returns `None` when no JWT secret +/// is configured — callers must skip signals collection rather than fall +/// back to a constant, which would let any attacker predict visitor IDs and +/// correlate sessions across users. +fn signal_visitor_secret(jwt_secret: &Option) -> Option { + jwt_secret.clone().filter(|s| !s.is_empty()) } /// Gateway server configuration. @@ -218,6 +211,7 @@ pub struct GatewayServer { signals_collector: Option, signals_anonymize_ip: bool, signals_geoip: Option, + signals_rate_limit_per_minute: Option, custom_routes: Option, rate_limiter: Option>, role_resolver: Option, @@ -255,6 +249,7 @@ impl GatewayServer { signals_collector: None, signals_anonymize_ip: false, signals_geoip: None, + signals_rate_limit_per_minute: None, custom_routes: None, rate_limiter: None, role_resolver: None, @@ -330,6 +325,12 @@ impl GatewayServer { } + /// Override the default per-IP `/signal` rate limit (requests per minute). + pub fn with_signals_rate_limit_per_minute(mut self, max: u32) -> Self { + self.signals_rate_limit_per_minute = Some(max); + self + } + /// Set the GeoIP resolver for country code lookups from client IPs. pub fn with_signals_geoip(mut self, resolver: crate::signals::geoip::GeoIpResolver) -> Self { self.signals_geoip = Some(resolver); self @@ -424,8 +425,14 @@ impl GatewayServer { rpc.set_role_resolver(resolver.clone()); } if let Some(collector) = &self.signals_collector { - let secret = signal_visitor_secret(&self.config.auth.jwt_secret); - rpc.set_signals_collector(collector.clone(), secret); + match signal_visitor_secret(&self.config.auth.jwt_secret) { + Some(secret) => rpc.set_signals_collector(collector.clone(), secret), + None => tracing::error!( + "Signals collector configured but `[auth] jwt_secret` is unset. \ + Signals are disabled to avoid predictable visitor IDs. Configure \ + a jwt_secret in forge.toml to enable signals." 
+ ), + } } let rpc_handler_state = Arc::new(rpc); @@ -446,21 +453,40 @@ impl GatewayServer { // with credentials per the CORS spec, so we enumerate them. let cors = if self.config.cors_enabled { if self.config.cors_origins.iter().any(|o| o == "*") { - // Wildcard origin can't use credentials. Loud at startup so - // operators don't ship `cors_origins = ["*"]` to production - // by accident — credentialed requests will silently fail - // (no `Access-Control-Allow-Credentials`) and there's no - // origin allowlist limiting cross-site abuse of the gateway. - tracing::warn!( - "CORS wildcard (`cors_origins = [\"*\"]`) is enabled. \ - Credentialed requests will fail and any origin can \ - reach the gateway. Set explicit origins for \ - production deployments." - ); - CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any) + let is_production = std::env::var("FORGE_ENV") + .ok() + .as_deref() + .map(|s| s.eq_ignore_ascii_case("production") || s.eq_ignore_ascii_case("prod")) + .unwrap_or(false); + if is_production { + // In production a wildcard origin opens the gateway to any + // site. Refuse the wildcard outright: CORS is disabled and + // every cross-origin request will fail at the browser. The + // alternative — silently accepting every Origin — would let + // a malicious site issue same-credentials requests. + tracing::error!( + "CORS wildcard (`cors_origins = [\"*\"]`) is forbidden when \ + FORGE_ENV=production. CORS will be disabled. Configure \ + explicit origins to re-enable cross-origin access." + ); + CorsLayer::new() + } else { + // Wildcard origin can't use credentials. Loud at startup so + // operators don't ship `cors_origins = ["*"]` to production + // by accident — credentialed requests will silently fail + // (no `Access-Control-Allow-Credentials`) and there's no + // origin allowlist limiting cross-site abuse of the gateway. + tracing::warn!( + "CORS wildcard (`cors_origins = [\"*\"]`) is enabled. 
\ + Credentialed requests will fail and any origin can \ + reach the gateway. Set explicit origins for \ + production deployments." + ); + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any) + } } else { use axum::http::Method; let origins: Vec<_> = self @@ -501,7 +527,6 @@ impl GatewayServer { let sse_state = Arc::new(SseState::with_config( self.reactor.clone(), - auth_middleware_state.clone(), super::sse::SseConfig { max_sessions: self.config.sse_max_sessions, max_subscriptions_per_session: self @@ -569,6 +594,7 @@ impl GatewayServer { let sse_router = Router::new() .route("/events", get(sse_handler)) + .route("/events/ticket", post(sse_ticket_handler)) .route("/subscribe", post(sse_subscribe_handler)) .route("/unsubscribe", post(sse_unsubscribe_handler)) .route("/subscribe-job", post(sse_job_subscribe_handler)) @@ -598,23 +624,60 @@ impl GatewayServer { ); } - let mut signals_router = Router::new(); - if let Some(collector) = &self.signals_collector { + // The real collector only mounts when we have a visitor-ID secret: a + // constant fallback secret would make visitor IDs predictable, letting + // an attacker forge cross-user correlations. When it can't mount we + // still answer POST /signal with a 204 no-op (the else branch) instead + // of leaving the path to the SPA fallback, which 405s every beacon. + let signal_secret = signal_visitor_secret(&self.config.auth.jwt_secret); + if signal_secret.is_none() && self.signals_collector.is_some() { + tracing::error!( + "Signals collector configured but `[auth] jwt_secret` is unset. \ + Signal collection is disabled to avoid predictable visitor IDs." 
+ ); + } + let signals_router = if let (Some(collector), Some(server_secret)) = + (&self.signals_collector, signal_secret) + { let signals_state = Arc::new(crate::signals::endpoints::SignalsState { collector: collector.clone(), pool: self.db.primary().clone(), - server_secret: signal_visitor_secret(&self.config.auth.jwt_secret), + server_secret, anonymize_ip: self.signals_anonymize_ip, geoip: self.signals_geoip.clone(), - rate_limiter: Arc::new(crate::signals::rate_limit::SignalRateLimiter::new()), + rate_limiter: Arc::new(match self.signals_rate_limit_per_minute { + Some(max) => crate::signals::rate_limit::SignalRateLimiter::with_limit(max), + None => crate::signals::rate_limit::SignalRateLimiter::new(), + }), }); - signals_router = Router::new() + // Tighter body cap for the signal endpoint specifically. The + // batch + per-event size caps in signals/endpoints.rs would + // otherwise sit behind the 1 MB JSON default; clamping the + // request body to MAX_BATCH_SIZE * MAX_EVENT_BYTES + slack stops + // unauthenticated clients from forcing us to deserialize multi- + // MB JSON before validation runs. + const MAX_SIGNAL_BODY_BYTES: usize = 512 * 1024; + Router::new() .route("/signal", post(crate::signals::endpoints::signal_handler)) - .with_state(signals_state); - } + .layer(DefaultBodyLimit::max(MAX_SIGNAL_BODY_BYTES)) + .with_state(signals_state) + } else { + // Signals are disabled (or `[auth] jwt_secret` is unset). Clients + // enable web-vitals and page-view tracking by default and POST to + // /signal regardless. Without a route here the request falls through + // to the SPA static fallback, which rejects non-GET with a 405 the + // browser logs as a console error. Accept and drop it: a 204 stores + // nothing and mints no visitor ID, so it doesn't reintroduce the + // predictable-ID risk the real handler guards against. 
+ Router::new().route( + "/signal", + post(|| async { axum::http::StatusCode::NO_CONTENT }), + ) + }; let admin_router = admin_router(AdminState { db_pool: self.db.primary().clone(), + reactor: Some(self.reactor.clone()), }); main_router = main_router @@ -689,6 +752,23 @@ impl GatewayServer { .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?; tracing::info!("Reactor started for real-time updates"); + // Surface the trusted-proxy posture at startup. The XFF chain is only + // honored when the immediate peer is in `trusted_proxies` — if the + // operator forgot to add the proxy CIDR, every request silently falls + // back to the peer IP and rate limits / geo signals get pinned to the + // proxy. A loud one-shot log keeps that misconfiguration visible. + if !self.config.trusted_proxies.is_empty() { + tracing::info!( + ranges = self.config.trusted_proxies.len(), + "Trusted proxies configured; X-Forwarded-For honored only from peers in these CIDRs" + ); + } else { + tracing::info!( + "No trusted_proxies configured; X-Forwarded-For headers are ignored and \ + client IPs come from the immediate peer" + ); + } + tracing::info!("Gateway server listening on {}", addr); let listener = bind_listener(addr, tls.as_ref()).await?; @@ -806,10 +886,15 @@ async fn readiness_handler( } async fn handle_middleware_error(err: BoxError) -> axum::response::Response { + // Distinguish error categories so clients can react correctly. Timeout + // signals "retry later"; anything else gets surfaced as 500 so it shows + // up in error budgets rather than masquerading as a transient 503 that + // hides genuine middleware bugs. 
let rpc_err = if err.is::() { RpcError::new("REQUEST_TIMEOUT", "Request timed out") } else { - RpcError::new("SERVICE_UNAVAILABLE", "Server overloaded") + tracing::error!(error = %err, "Unexpected middleware error"); + RpcError::new("INTERNAL_ERROR", "Internal server error") }; RpcResponse::error(rpc_err).into_response() } @@ -919,7 +1004,7 @@ async fn api_version_middleware( let is_rpc = req.uri().path().starts_with("/rpc"); if is_rpc && let Some(accept) = req.headers().get(axum::http::header::ACCEPT) { let accept_str = accept.to_str().unwrap_or(""); - if accept_str != "*/*" && !accept_str.is_empty() && !accept_str.contains(FORGE_API_V1) { + if !accept_str.is_empty() && !accept_allows_v1(accept_str) { return RpcResponse::error(RpcError::new( "UNSUPPORTED_API_VERSION", format!( @@ -933,6 +1018,39 @@ async fn api_version_middleware( next.run(req).await } +/// Returns true if the `Accept` header value tolerates the v1 forge media +/// type. Each comma-separated media range is checked individually so that +/// `Accept: application/json, application/vnd.forge.v1+json` matches even +/// though `contains` would also have matched a misleading substring. Quality +/// values (`;q=0`) explicitly disable the match. +fn accept_allows_v1(accept: &str) -> bool { + for raw in accept.split(',') { + let mut parts = raw.split(';').map(str::trim); + let Some(media) = parts.next() else { continue }; + if media.is_empty() { + continue; + } + let mut q = 1.0_f32; + for param in parts { + if let Some(qv) = param.strip_prefix("q=") + && let Ok(parsed) = qv.parse::() + { + q = parsed; + } + } + if q <= 0.0 { + continue; + } + if media.eq_ignore_ascii_case(FORGE_API_V1) + || media == "*/*" + || media.eq_ignore_ascii_case("application/*") + { + return true; + } + } + false +} + /// Wraps each request in a span with HTTP semantics and OpenTelemetry /// context propagation. Incoming `traceparent` headers are extracted so /// that spans join the caller's distributed trace. 
@@ -1117,9 +1235,10 @@ struct JsonDepthConfig { } /// Middleware that rejects request bodies whose JSON nesting depth exceeds -/// `max_depth`. Runs on all POST requests regardless of Content-Type, because -/// serde_json will parse the body downstream even if the client lies about -/// the content type. +/// `max_depth`. Runs on every method that can carry a body (POST/PUT/PATCH/ +/// DELETE) regardless of Content-Type, because serde_json will parse the +/// body downstream even if the client lies about the content type. GET and +/// HEAD are skipped because Axum drops their bodies. /// /// The body is buffered, inspected, and re-inserted into the request so that /// downstream handlers see the original bytes. @@ -1129,8 +1248,13 @@ async fn json_depth_check_middleware( next: axum::middleware::Next, ) -> axum::response::Response { use axum::body::Body; + use axum::http::Method; - if req.method() != axum::http::Method::POST || config.max_depth == 0 { + let method_has_body = matches!( + *req.method(), + Method::POST | Method::PUT | Method::PATCH | Method::DELETE + ); + if !method_has_body || config.max_depth == 0 { return next.run(req).await; } @@ -1237,12 +1361,18 @@ mod tests { #[test] fn signal_visitor_secret_uses_jwt_secret_when_present() { let secret = Some("my-jwt-secret".to_string()); - assert_eq!(signal_visitor_secret(&secret), "my-jwt-secret"); + assert_eq!( + signal_visitor_secret(&secret), + Some("my-jwt-secret".to_string()) + ); } #[test] - fn signal_visitor_secret_falls_back_to_default_when_absent() { - assert_eq!(signal_visitor_secret(&None), DEFAULT_SIGNAL_SECRET); + fn signal_visitor_secret_is_none_when_jwt_secret_absent() { + // Refuse to mint a constant fallback — predictable visitor IDs would + // let an attacker correlate sessions across users. 
+ assert_eq!(signal_visitor_secret(&None), None); + assert_eq!(signal_visitor_secret(&Some(String::new())), None); } #[test] diff --git a/crates/forge-runtime/src/gateway/sse.rs b/crates/forge-runtime/src/gateway/sse.rs index 7dd18174..70efd5c5 100644 --- a/crates/forge-runtime/src/gateway/sse.rs +++ b/crates/forge-runtime/src/gateway/sse.rs @@ -5,7 +5,7 @@ use std::convert::Infallible; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::Duration; +use std::time::{Duration, Instant}; use axum::Json; use axum::extract::{Extension, Query, State}; @@ -21,6 +21,71 @@ use subtle::ConstantTimeEq; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; +/// Maximum number of outstanding SSE tickets held in memory. Each ticket is +/// small (~200 B), so 10k is a ~2 MB cap. New issuance evicts expired +/// entries before falling back to a hard reject. +const MAX_SSE_TICKETS: usize = 10_000; +/// SSE ticket lifetime. The original ISSUES.md item requires "≤ 60 s"; +/// 30 s is short enough to bound replay risk yet generous for slow clients +/// to complete the POST + EventSource handshake. +const SSE_TICKET_TTL_SECS: u64 = 30; +const SSE_TICKET_TTL_SECS_STR: &str = "30"; + +/// One-shot SSE authentication ticket. Issued via `POST /events/ticket` +/// against a validated bearer header; consumed exactly once by the SSE +/// `GET /events?ticket=…` upgrade. Stored in-process (no DB), bound to +/// the caller's resolved client IP so a leaked ticket from logs cannot +/// be replayed from a different origin. +struct TicketEntry { + auth: AuthContext, + client_ip: Option, + expires_at: Instant, +} + +/// Process-local store of outstanding SSE tickets. Bounded, TTL-evicted, +/// one-time-use. Tickets are opaque uuid v4 strings (122 bits). +#[derive(Default)] +struct TicketStore { + entries: DashMap, +} + +impl TicketStore { + fn new() -> Self { + Self { + entries: DashMap::new(), + } + } + + /// Drop expired tickets. 
Called opportunistically on insert and consume. + fn sweep_expired(&self) { + let now = Instant::now(); + self.entries.retain(|_, entry| entry.expires_at > now); + } + + /// Insert a fresh ticket. Returns `false` when at capacity even after + /// sweeping expired entries (caller should reject with 503). + fn insert(&self, ticket: String, entry: TicketEntry) -> bool { + if self.entries.len() >= MAX_SSE_TICKETS { + self.sweep_expired(); + if self.entries.len() >= MAX_SSE_TICKETS { + return false; + } + } + self.entries.insert(ticket, entry); + true + } + + /// Atomically remove and validate a ticket. Returns `None` if missing, + /// already consumed, or expired. + fn consume(&self, ticket: &str) -> Option { + let (_, entry) = self.entries.remove(ticket)?; + if entry.expires_at <= Instant::now() { + return None; + } + Some(entry) + } +} + /// Wraps an mpsc::Receiver as a Stream for SSE. struct ReceiverStream { rx: mpsc::Receiver, @@ -36,7 +101,6 @@ impl Stream for ReceiverStream { use forge_core::function::AuthContext; use forge_core::realtime::{SessionId, SubscriptionId}; -use super::auth::AuthMiddleware; use crate::realtime::Reactor; use crate::realtime::RealtimeMessage; @@ -62,13 +126,6 @@ fn same_principal(a: &AuthContext, b: &AuthContext) -> bool { } } -fn resolve_sse_auth_context( - request_auth: &AuthContext, - query_auth: Option, -) -> AuthContext { - query_auth.unwrap_or_else(|| request_auth.clone()) -} - #[allow(clippy::result_large_err)] fn authorize_session_access( session: &SseSessionData, @@ -183,8 +240,9 @@ impl Default for SseConfig { /// SSE query parameters. #[derive(Debug, Deserialize)] pub struct SseQuery { - /// Authentication token. - pub token: Option, + /// One-shot ticket obtained from `POST /events/ticket`. Required when + /// `EventSource` cannot send an `Authorization` header (browsers). 
+ pub ticket: Option, } struct SseSessionData { @@ -199,7 +257,6 @@ struct SseSessionData { #[derive(Clone)] pub struct SseState { reactor: Arc, - auth_middleware: Arc, /// Per-session data: auth context and subscription mappings (sharded). sessions: Arc>, /// Per-user session count for O(1) limit enforcement. @@ -208,28 +265,26 @@ pub struct SseState { ip_session_counts: Arc>, /// Per-user subscription count across all sessions. user_subscription_counts: Arc>, + /// One-shot SSE auth tickets. See `TicketStore` for semantics. + tickets: Arc, config: SseConfig, } impl SseState { /// Create new SSE state with default config. - pub fn new(reactor: Arc, auth_middleware: Arc) -> Self { - Self::with_config(reactor, auth_middleware, SseConfig::default()) + pub fn new(reactor: Arc) -> Self { + Self::with_config(reactor, SseConfig::default()) } /// Create new SSE state with custom config. - pub fn with_config( - reactor: Arc, - auth_middleware: Arc, - config: SseConfig, - ) -> Self { + pub fn with_config(reactor: Arc, config: SseConfig) -> Self { Self { reactor, - auth_middleware, sessions: Arc::new(DashMap::new()), user_session_counts: Arc::new(DashMap::new()), ip_session_counts: Arc::new(DashMap::new()), user_subscription_counts: Arc::new(DashMap::new()), + tickets: Arc::new(TicketStore::new()), config, } } @@ -570,23 +625,51 @@ pub async fn sse_handler( let keepalive_secs = state.config.keepalive_interval_secs; let cancel_token = CancellationToken::new(); - let query_auth = if let Some(token) = &query.token { - match state.auth_middleware.validate_token_async(token).await { - Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)), - Err(e) => { - tracing::warn!("SSE token validation failed: {}", e); + let client_ip = resolved_ip.0; + + // Authentication resolution order: + // 1. If the request was authenticated by the `Authorization` header + // (auth_middleware ran upstream), use that. Header is authoritative. + // 2. 
Otherwise, if a `?ticket=` is supplied, consume it. The ticket + // was minted against a validated bearer header at `/events/ticket` + // and is bound to the resolved client IP, so a leaked URL cannot + // be replayed from a different origin. + // 3. Otherwise, treat the connection as anonymous. + // + // JWTs are deliberately never accepted in the URL. Query strings appear + // in access logs, browser history, referrer headers, and proxy caches. + let auth_context = if request_auth.is_authenticated() { + request_auth.clone() + } else if let Some(ticket) = &query.ticket { + match state.tickets.consume(ticket) { + Some(entry) => { + // Bind ticket to the IP that requested it. Reject if either + // side has no resolved IP (strict) or they disagree. The + // server-side store already guarantees one-shot use. + let ip_match = match (&entry.client_ip, &client_ip) { + (Some(a), Some(b)) => a == b, + _ => false, + }; + if !ip_match { + tracing::warn!("SSE ticket IP mismatch; rejecting"); + return super::response::RpcResponse::error( + super::response::RpcError::unauthorized("SSE ticket IP mismatch"), + ) + .into_response(); + } + entry.auth + } + None => { + tracing::warn!("SSE ticket missing, expired, or already consumed"); return super::response::RpcResponse::error( - super::response::RpcError::unauthorized("Invalid authentication token"), + super::response::RpcError::unauthorized("Invalid or expired SSE ticket"), ) .into_response(); } } } else { - None + request_auth.clone() }; - let auth_context = resolve_sse_auth_context(&request_auth, query_auth); - - let client_ip = resolved_ip.0; // UUIDv4 provides 122 bits of randomness, sufficient for session secret entropy let session_secret = uuid::Uuid::new_v4().to_string(); // Authenticated sessions without an explicit exp claim get a default @@ -750,6 +833,61 @@ pub async fn sse_handler( .into_response() } +/// Response body for `POST /events/ticket`. 
+#[derive(Debug, Serialize)] +pub struct SseTicketResponse { + /// Opaque single-use ticket. Send back as `?ticket=…` on the next + /// `GET /events` request. + pub ticket: String, + /// Lifetime hint in seconds; clients should connect well before this. + pub expires_in_secs: u64, +} + +/// SSE ticket handler for POST `/events/ticket`. Issues a short-lived, +/// IP-bound, single-use ticket so browsers (whose `EventSource` cannot +/// set custom headers) can authenticate the SSE upgrade without putting +/// a long-lived JWT in the URL. +/// +/// Requires an authenticated bearer header. Anonymous callers get 401: +/// anonymous SSE streams can simply connect to `/events` without a ticket. +pub async fn sse_ticket_handler( + State(state): State>, + Extension(request_auth): Extension, + Extension(resolved_ip): Extension, +) -> impl IntoResponse { + if !request_auth.is_authenticated() { + return super::response::RpcResponse::error(super::response::RpcError::unauthorized( + "Authentication required to mint an SSE ticket", + )) + .into_response(); + } + + let ticket = uuid::Uuid::new_v4().to_string(); + let entry = TicketEntry { + auth: request_auth.clone(), + client_ip: resolved_ip.0, + expires_at: Instant::now() + Duration::from_secs(SSE_TICKET_TTL_SECS), + }; + + if !state.tickets.insert(ticket.clone(), entry) { + return ( + StatusCode::SERVICE_UNAVAILABLE, + [(axum::http::header::RETRY_AFTER, SSE_TICKET_TTL_SECS_STR)], + Json( + SseError::new("SSE_TICKET_CAPACITY", "Too many outstanding SSE tickets") + .with_retry_after(SSE_TICKET_TTL_SECS), + ), + ) + .into_response(); + } + + Json(SseTicketResponse { + ticket, + expires_in_secs: SSE_TICKET_TTL_SECS, + }) + .into_response() +} + /// Convert realtime message to SSE message. 
fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option { match msg { @@ -1232,26 +1370,51 @@ mod tests { } #[test] - fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() { - let request_auth = + fn ticket_store_consume_is_one_shot() { + let store = TicketStore::new(); + let auth = AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); - - let resolved = resolve_sse_auth_context(&request_auth, None); - - assert!(resolved.is_authenticated()); - assert_eq!(resolved.principal_id(), request_auth.principal_id()); + let entry = TicketEntry { + auth, + client_ip: Some("1.2.3.4".into()), + expires_at: Instant::now() + Duration::from_secs(30), + }; + assert!(store.insert("t1".into(), entry)); + assert!(store.consume("t1").is_some()); + assert!(store.consume("t1").is_none(), "second consume must fail"); } #[test] - fn resolve_sse_auth_context_prefers_query_token_when_present() { - let request_auth = - AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); - let query_auth = + fn ticket_store_rejects_expired() { + let store = TicketStore::new(); + let auth = AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); + let entry = TicketEntry { + auth, + client_ip: None, + expires_at: Instant::now() - Duration::from_secs(1), + }; + assert!(store.insert("t2".into(), entry)); + assert!( + store.consume("t2").is_none(), + "expired ticket must not validate" + ); + } - let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone())); - - assert_eq!(resolved.principal_id(), query_auth.principal_id()); + #[test] + fn ticket_store_caps_at_max() { + let store = TicketStore::new(); + let make_entry = || TicketEntry { + auth: AuthContext::unauthenticated(), + client_ip: None, + expires_at: Instant::now() + Duration::from_secs(30), + }; + // Fill to cap. 
+ for i in 0..MAX_SSE_TICKETS { + assert!(store.insert(format!("k{i}"), make_entry())); + } + // One more should fail (no expired entries to sweep). + assert!(!store.insert("overflow".into(), make_entry())); } #[test] diff --git a/crates/forge-runtime/src/gateway/tls.rs b/crates/forge-runtime/src/gateway/tls.rs index 3e7b6655..fcb71312 100644 --- a/crates/forge-runtime/src/gateway/tls.rs +++ b/crates/forge-runtime/src/gateway/tls.rs @@ -274,11 +274,35 @@ fn read_pem_certs(path: &str) -> Result>> { } fn read_pem_key(path: &str) -> Result> { + warn_if_key_world_readable(path); PrivateKeyDer::from_pem_file(path).map_err(|e| { ForgeError::config(format!("failed to read PEM private key from '{path}': {e}")) }) } +/// Emit a loud warning if the TLS private key is readable by group or other. +/// We don't refuse to start — operators may rely on a key-management daemon +/// that enforces its own ACL model — but silently loading a 0644 key would +/// be a footgun on shared hosts. +#[cfg(unix)] +fn warn_if_key_world_readable(path: &str) { + use std::os::unix::fs::MetadataExt; + let Ok(meta) = std::fs::metadata(path) else { + return; + }; + let mode = meta.mode() & 0o777; + if mode & 0o077 != 0 { + tracing::warn!( + path = %path, + mode = format!("{:o}", mode), + "TLS private key is readable by group or other; tighten to 0600 (chmod 600)" + ); + } +} + +#[cfg(not(unix))] +fn warn_if_key_world_readable(_path: &str) {} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { diff --git a/crates/forge-runtime/src/jobs/dispatcher.rs b/crates/forge-runtime/src/jobs/dispatcher.rs index 4277afe0..198c3c51 100644 --- a/crates/forge-runtime/src/jobs/dispatcher.rs +++ b/crates/forge-runtime/src/jobs/dispatcher.rs @@ -1,10 +1,9 @@ use std::future::Future; use std::pin::Pin; -use std::time::Duration; use chrono::{DateTime, Utc}; use forge_core::function::JobDispatch; -use forge_core::job::{ForgeJob, JobInfo, JobPriority}; +use forge_core::job::JobInfo; use 
uuid::Uuid; use super::queue::{JobQueue, JobRecord}; @@ -21,93 +20,6 @@ impl JobDispatcher { Self { queue, registry } } - pub async fn dispatch( - &self, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - self.dispatch_with_info(&info, serde_json::to_value(args)?, owner_subject) - .await - } - - pub async fn dispatch_in( - &self, - delay: Duration, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let scheduled_at = Utc::now() - + chrono::Duration::from_std(delay) - .map_err(|_| forge_core::ForgeError::InvalidArgument("delay too large".into()))?; - self.dispatch_at_with_info( - &info, - serde_json::to_value(args)?, - scheduled_at, - owner_subject, - ) - .await - } - - pub async fn dispatch_at( - &self, - at: DateTime, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - self.dispatch_at_with_info(&info, serde_json::to_value(args)?, at, owner_subject) - .await - } - - pub async fn dispatch_by_name( - &self, - job_type: &str, - args: serde_json::Value, - owner_subject: Option, - ) -> Result { - let entry = self.registry.get(job_type).ok_or_else(|| { - forge_core::ForgeError::NotFound(format!("Job type '{}' not found", job_type)) - })?; - - self.dispatch_with_info(&entry.info, args, owner_subject) - .await - } - - async fn dispatch_with_info( - &self, - info: &JobInfo, - args: serde_json::Value, - owner_subject: Option, - ) -> Result { - let mut job = JobRecord::new( - info.name, - args, - info.priority, - info.retry.max_attempts as i32, - ) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - /// Request cancellation for a job. 
/// /// If `caller_subject` is provided, the cancellation will only succeed if @@ -124,32 +36,6 @@ impl JobDispatcher { .map_err(forge_core::ForgeError::Database) } - async fn dispatch_at_with_info( - &self, - info: &JobInfo, - args: serde_json::Value, - scheduled_at: DateTime, - owner_subject: Option, - ) -> Result { - let mut job = JobRecord::new( - info.name, - args, - info.priority, - info.retry.max_attempts as i32, - ) - .with_scheduled_at(scheduled_at) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - async fn dispatch_with_info_and_tenant( &self, info: &JobInfo, @@ -203,63 +89,6 @@ impl JobDispatcher { .await .map_err(forge_core::ForgeError::Database) } - - pub async fn dispatch_idempotent( - &self, - idempotency_key: impl Into, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let mut job = JobRecord::new( - info.name, - serde_json::to_value(args)?, - info.priority, - info.retry.max_attempts as i32, - ) - .with_idempotency_key(idempotency_key) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - - pub async fn dispatch_with_priority( - &self, - priority: JobPriority, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let mut job = JobRecord::new( - info.name, - serde_json::to_value(args)?, - priority, - info.retry.max_attempts as i32, - ) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } } impl JobDispatch for JobDispatcher { @@ 
-442,7 +271,7 @@ mod integration_tests { let dispatcher = dispatcher_with_registry(db.pool().clone(), |_| {}); let err = dispatcher - .dispatch_by_name("ghost", serde_json::json!({}), None) + .dispatch_by_name("ghost", serde_json::json!({}), None, None) .await .expect_err("unknown job must error"); @@ -470,6 +299,7 @@ mod integration_tests { "ship", serde_json::json!({"to": "warehouse"}), Some("u-1".into()), + None, ) .await .unwrap(); @@ -614,7 +444,7 @@ mod integration_tests { }); let job_id = dispatcher - .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into())) + .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into()), None) .await .unwrap(); @@ -636,7 +466,7 @@ mod integration_tests { }); let job_id = dispatcher - .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into())) + .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into()), None) .await .unwrap(); diff --git a/crates/forge-runtime/src/jobs/executor.rs b/crates/forge-runtime/src/jobs/executor.rs index 9069a131..f796149e 100644 --- a/crates/forge-runtime/src/jobs/executor.rs +++ b/crates/forge-runtime/src/jobs/executor.rs @@ -10,6 +10,11 @@ use super::queue::{JobQueue, JobRecord}; use super::registry::{JobEntry, JobRegistry}; use crate::observability; +/// How often to poll the progress channel between updates from the running job. +/// Short enough that progress propagates to subscribers within one frame; long +/// enough that an idle job doesn't burn CPU on the polling task. +const PROGRESS_POLL_INTERVAL: Duration = Duration::from_millis(50); + /// Executes jobs with timeout and retry handling. 
pub struct JobExecutor { queue: JobQueue, @@ -139,7 +144,7 @@ impl JobExecutor { } } Err(std::sync::mpsc::TryRecvError::Empty) => { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + tokio::time::sleep(PROGRESS_POLL_INTERVAL).await; } Err(std::sync::mpsc::TryRecvError::Disconnected) => { break; @@ -171,6 +176,24 @@ impl JobExecutor { c }; if let Some(ref subject) = job.owner_subject { + // Defense in depth: the job row stores `tenant_id` and + // `owner_subject` independently. We trust the dispatcher to pair + // them correctly, but if a stale/forged dispatch slipped a + // mismatched pair through, the handler would execute cross-tenant + // (#6 in issues doc). There's no system `users` table to consult, + // so the framework can't reject the row authoritatively here. Warn + // when an owner principal is present without a tenant — the shape + // most likely to indicate a dispatch path that lost tenancy. (Single- + // tenant apps legitimately dispatch with no tenant_id, so this is a + // warning, not an assertion.) + if job.tenant_id.is_none() && uuid::Uuid::parse_str(subject).is_ok() { + tracing::warn!( + job_id = %job.id, + job_type = %job.job_type, + owner_subject = %subject, + "Job has UUID owner_subject but no tenant_id; if the app is multi-tenant this is a dispatch-side bug — the handler will run with empty tenant scope" + ); + } let mut claims = std::collections::HashMap::new(); if let Some(tid) = job.tenant_id { claims.insert( @@ -240,13 +263,40 @@ impl JobExecutor { } }); + // Drop guard: ensures the heartbeat task is aborted even if `execute` + // is cancelled (e.g. `drain_jobs` calling `abort_all` on shutdown). + // Without this the task would keep refreshing `last_heartbeat` for up + // to 30s, blocking `release_stale` from requeueing the row (#9 in + // issues doc). 
+ struct HeartbeatGuard { + stop: tokio::sync::watch::Sender, + handle: Option>, + } + impl Drop for HeartbeatGuard { + fn drop(&mut self) { + let _ = self.stop.send(true); + if let Some(h) = self.handle.take() { + h.abort(); + } + } + } + let mut heartbeat_guard = HeartbeatGuard { + stop: heartbeat_stop_tx, + handle: Some(heartbeat_task), + }; + let job_timeout = entry.info.timeout; let exec_start = std::time::Instant::now(); let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await; let exec_duration_ms = exec_start.elapsed().as_millis() as i32; - let _ = heartbeat_stop_tx.send(true); - let _ = heartbeat_task.await; + // Happy path: signal the heartbeat task and await it cleanly. The + // guard will still run on early return, abort() is a no-op on a + // joined handle. + let _ = heartbeat_guard.stop.send(true); + if let Some(h) = heartbeat_guard.handle.take() { + let _ = h.await; + } let ttl = entry.info.ttl; @@ -326,8 +376,14 @@ impl JobExecutor { let error_msg = format!("Job timed out after {:?}", job_timeout); let should_retry = job.attempts < job.max_attempts; + // Mirror the failure path: honor the job's configured backoff + // strategy rather than hardcoding 60s (#13 in issues doc). let retry_delay = if should_retry { - Some(chrono::Duration::seconds(60)) + let std_delay = entry.info.retry.calculate_backoff(job.attempts as u32); + Some( + chrono::Duration::from_std(std_delay) + .unwrap_or(chrono::Duration::seconds(60)), + ) } else { None }; diff --git a/crates/forge-runtime/src/jobs/queue.rs b/crates/forge-runtime/src/jobs/queue.rs index 03b82452..c1d03f45 100644 --- a/crates/forge-runtime/src/jobs/queue.rs +++ b/crates/forge-runtime/src/jobs/queue.rs @@ -137,15 +137,30 @@ impl JobQueue { ) -> Result { // Fast path: check for existing idempotent job before attempting INSERT. // The UNIQUE partial index on idempotency_key guards against races. 
+ // + // Scope the lookup by `job_type` so apps reusing the same key across + // multiple job types (e.g. `payment-{id}` for `charge` and `refund`) + // get the right job back. NOTE: the partial unique index in + // `v001_initial.sql` does NOT yet include `job_type`, so cross-type + // idempotency collisions are still rejected at the database level + // even though this check would accept them. Tracking issue: update + // the index to `(job_type, idempotency_key)` once migration is safe. if let Some(ref key) = job.idempotency_key { - let existing = sqlx::query_scalar!( + // Runtime query: lookup gained a `job_type` filter so apps reusing + // the same key across job types map to the right row. Stays as + // sqlx::query rather than query_scalar! to avoid invalidating the + // offline cache on a non-critical path. + #[allow(clippy::disallowed_methods)] + let existing: Option = sqlx::query_scalar( r#" SELECT id FROM forge_jobs WHERE idempotency_key = $1 + AND job_type = $2 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') "#, - key ) + .bind(key) + .bind(&job.job_type) .fetch_optional(&mut *conn) .await?; @@ -187,15 +202,19 @@ impl JobQueue { .await?; // If ON CONFLICT fired (race with another enqueue), fetch the winner's ID. + // Runtime query for the same reason as the fast-path lookup above. 
if let Some(ref key) = job.idempotency_key { - let id = sqlx::query_scalar!( + #[allow(clippy::disallowed_methods)] + let id: Option = sqlx::query_scalar( r#" SELECT id FROM forge_jobs WHERE idempotency_key = $1 + AND job_type = $2 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') "#, - key ) + .bind(key) + .bind(&job.job_type) .fetch_optional(&mut *conn) .await?; @@ -448,7 +467,11 @@ impl JobQueue { "#, job_id, error, - delay.num_seconds() as f64, + // Millisecond precision, not num_seconds(): the latter truncates + // to whole seconds, so a sub-second backoff (the common first + // retry: 1s base - 25% jitter = 0.75s) collapses to secs => 0, + // dropping the backoff entirely and retrying instantly. + delay.num_milliseconds() as f64 / 1000.0, ) .execute(&self.pool) .await?; @@ -609,6 +632,11 @@ impl JobQueue { } let retention_secs = Self::DEFAULT_RETENTION.as_secs() as f64; + // Defense-in-depth: the SELECT guard above already verified the caller + // owns this row, but include the ownership predicate directly in the + // UPDATE so a future refactor that reorders these blocks can't silently + // drop the check. `caller_subject IS NULL` lets system-side callers + // (no caller) cancel rows with no owner_subject. #[allow(clippy::disallowed_methods)] let updated = sqlx::query( r#" @@ -620,11 +648,17 @@ impl JobQueue { expires_at = NOW() + make_interval(secs => $3) WHERE id = $1 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') + AND ( + owner_subject IS NULL + OR $4::text IS NULL + OR owner_subject = $4::text + ) "#, ) .bind(job_id) .bind(reason) .bind(retention_secs) + .bind(caller_subject) .execute(&self.pool) .await?; @@ -715,6 +749,13 @@ impl JobQueue { ( status = 'claimed' AND claimed_at < NOW() - make_interval(secs => $1) + -- Don't yank a claim that has produced a recent heartbeat: + -- the worker may have transitioned to running on its own + -- side and we just haven't seen the `start()` UPDATE land yet. 
+ AND ( + last_heartbeat IS NULL + OR last_heartbeat < NOW() - make_interval(secs => $1) + ) ) OR ( status = 'running' @@ -883,7 +924,7 @@ mod integration_tests { let job_id = queue.enqueue(job).await.expect("Failed to enqueue"); let claimed = queue - .claim(worker_id, &[], true, 10) + .claim(worker_id, &["default".into()], true, 10) .await .expect("Failed to claim"); assert_eq!(claimed.len(), 1); @@ -912,11 +953,17 @@ mod integration_tests { } let worker1 = Uuid::new_v4(); - let batch1 = queue.claim(worker1, &[], true, 2).await.expect("claim1"); + let batch1 = queue + .claim(worker1, &["default".into()], true, 2) + .await + .expect("claim1"); assert_eq!(batch1.len(), 2); let worker2 = Uuid::new_v4(); - let batch2 = queue.claim(worker2, &[], true, 2).await.expect("claim2"); + let batch2 = queue + .claim(worker2, &["default".into()], true, 2) + .await + .expect("claim2"); assert_eq!(batch2.len(), 1); let ids1: Vec = batch1.iter().map(|j| j.id).collect(); @@ -943,7 +990,10 @@ mod integration_tests { let high = JobRecord::new("high_job", serde_json::json!({}), JobPriority::Critical, 3); queue.enqueue(high).await.expect("enqueue high"); - let claimed = queue.claim(worker_id, &[], true, 1).await.expect("claim"); + let claimed = queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); assert_eq!(claimed.len(), 1); assert_eq!(claimed[0].job_type, "high_job"); @@ -959,7 +1009,10 @@ mod integration_tests { let job = JobRecord::new("process", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue .complete(job_id, serde_json::json!({"result": "done"}), None) @@ -982,7 +1035,10 @@ mod integration_tests { let job = JobRecord::new("flaky", serde_json::json!({}), JobPriority::Normal, 3); let 
job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue @@ -1011,7 +1067,10 @@ mod integration_tests { let job = JobRecord::new("fatal", serde_json::json!({}), JobPriority::Normal, 1); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue @@ -1189,7 +1248,10 @@ mod integration_tests { let job = JobRecord::new("long_task", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue.heartbeat(job_id).await.expect("heartbeat"); @@ -1205,7 +1267,10 @@ mod integration_tests { let job = JobRecord::new("export", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue diff --git a/crates/forge-runtime/src/jobs/registry.rs b/crates/forge-runtime/src/jobs/registry.rs index 676d41a8..0aa55fb2 100644 --- a/crates/forge-runtime/src/jobs/registry.rs +++ b/crates/forge-runtime/src/jobs/registry.rs @@ -5,29 +5,9 @@ use std::sync::Arc; use forge_core::Result; use forge_core::job::{ForgeJob, JobContext, JobInfo}; +use forge_core::util::normalize_handler_args as normalize_args; use serde_json::Value; -/// Converts `null` to `{}` and unwraps single-key 
`args`/`input` envelopes. -fn normalize_args(args: Value) -> Value { - let unwrapped = match &args { - Value::Object(map) if map.len() == 1 => { - if map.contains_key("args") { - map.get("args").cloned().unwrap_or(Value::Null) - } else if map.contains_key("input") { - map.get("input").cloned().unwrap_or(Value::Null) - } else { - args - } - } - _ => args, - }; - - match &unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - _ => unwrapped, - } -} - pub type BoxedJobHandler = Arc< dyn Fn(&JobContext, Value) -> Pin> + Send + '_>> + Send @@ -151,53 +131,9 @@ impl JobRegistry { #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { use super::*; - use serde_json::json; - - // jobs/registry collapses null to {} so derive(Default) empty-struct args deserialize correctly; - // function/registry keeps null as-is for unit () — this divergence is the contract. - #[test] - fn normalize_args_converts_null_to_empty_object() { - assert_eq!(normalize_args(json!(null)), json!({})); - } - - #[test] - fn normalize_args_keeps_empty_object_intact() { - // `{}` (len 0) skips the envelope unwrap and the null branch. - assert_eq!(normalize_args(json!({})), json!({})); - } - #[test] - fn normalize_args_unwraps_args_envelope() { - assert_eq!(normalize_args(json!({"args": {"id": 7}})), json!({"id": 7})); - // The trailing null-to-{} step still applies after unwrap. - assert_eq!(normalize_args(json!({"args": null})), json!({})); - } - - #[test] - fn normalize_args_unwraps_input_envelope() { - assert_eq!(normalize_args(json!({"input": [1,2]})), json!([1, 2])); - } - - #[test] - fn normalize_args_keeps_other_single_key_objects_intact() { - // A handler with `struct Args { id: u32 }` must receive {"id":...} - // as-is — envelope stripping only fires for `args`/`input`. 
- assert_eq!(normalize_args(json!({"id": 7})), json!({"id": 7})); - } - - #[test] - fn normalize_args_keeps_multi_key_objects_intact() { - let v = json!({"a": 1, "b": 2}); - assert_eq!(normalize_args(v.clone()), v); - } - - #[test] - fn normalize_args_keeps_non_null_non_object_values_intact() { - assert_eq!(normalize_args(json!(42)), json!(42)); - assert_eq!(normalize_args(json!("x")), json!("x")); - assert_eq!(normalize_args(json!([1])), json!([1])); - assert_eq!(normalize_args(json!(true)), json!(true)); - } + // normalize_args is exercised via `forge_core::util` tests; jobs/registry + // now delegates to that shared helper. fn sample_info(name: &'static str) -> JobInfo { JobInfo { diff --git a/crates/forge-runtime/src/jobs/worker.rs b/crates/forge-runtime/src/jobs/worker.rs index 3cc9c788..9c6a59e4 100644 --- a/crates/forge-runtime/src/jobs/worker.rs +++ b/crates/forge-runtime/src/jobs/worker.rs @@ -166,25 +166,30 @@ impl Worker { let wakeup_notify = Arc::new(tokio::sync::Notify::new()); let wakeup_trigger = wakeup_notify.clone(); let wakeup_shutdown = shutdown_notify.clone(); - if let Some(mut rx) = self.notify_bus.subscribe("forge_jobs_available") { - tokio::spawn(async move { - loop { - tokio::select! { - _ = wakeup_shutdown.notified() => return, - result = rx.recv() => { - match result { - Ok(_) => wakeup_trigger.notify_one(), - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - tracing::debug!(missed = n, "Job wakeup receiver lagged"); - wakeup_trigger.notify_one(); + // Track the forwarder so shutdown can await it instead of leaking + // the JoinHandle (#8 in issues doc). + let forwarder_handle = self + .notify_bus + .subscribe("forge_jobs_available") + .map(|mut rx| { + tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = wakeup_shutdown.notified() => return, + result = rx.recv() => { + match result { + Ok(_) => wakeup_trigger.notify_one(), + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::debug!(missed = n, "Job wakeup receiver lagged"); + wakeup_trigger.notify_one(); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => return, } - Err(tokio::sync::broadcast::error::RecvError::Closed) => return, } } } - } + }) }); - } tracing::debug!( worker_id = %self.id, @@ -200,6 +205,9 @@ impl Worker { tracing::debug!(worker_id = %self.id, "Worker shutting down"); shutdown_notify.notify_waiters(); let _ = cleanup_handle.await; + if let Some(h) = forwarder_handle { + let _ = h.await; + } self.drain_jobs(&mut job_tasks).await; break; } diff --git a/crates/forge-runtime/src/kv/store.rs b/crates/forge-runtime/src/kv/store.rs index 0369ad11..31834378 100644 --- a/crates/forge-runtime/src/kv/store.rs +++ b/crates/forge-runtime/src/kv/store.rs @@ -5,6 +5,12 @@ use sqlx::PgPool; use forge_core::error::{ForgeError, Result}; +/// Hard cap on the value size accepted by `set` / `set_if_absent`. PG can +/// store much larger BYTEA blobs, but the KV API is meant for small +/// configuration / lock payloads, not blobs. Multi-MB values are almost +/// always a misuse (and round-trip the protocol per call). +const MAX_VALUE_BYTES: usize = 1024 * 1024; + /// PostgreSQL-backed key-value store. /// /// Provides a simple get/set/delete/set_if_absent/increment API over @@ -24,13 +30,37 @@ impl KvStore { Self { pool, namespace } } - fn prefixed_key(&self, key: &str) -> String { - format!("{}:{}", self.namespace, key) + fn prefixed_key(&self, key: &str) -> Result { + // Reject `:` in either namespace or key — the prefix separator must + // be unambiguous so `(namespace=a, key=b:foo)` and + // `(namespace=a:b, key=foo)` can't collide on the same physical key. 
+ if self.namespace.contains(':') { + return Err(ForgeError::InvalidArgument(format!( + "kv namespace must not contain ':' (got {:?})", + self.namespace + ))); + } + if key.contains(':') { + return Err(ForgeError::InvalidArgument( + "kv key must not contain ':' (reserved as namespace separator)".to_string(), + )); + } + Ok(format!("{}:{}", self.namespace, key)) + } + + fn check_value_size(value: &[u8]) -> Result<()> { + if value.len() > MAX_VALUE_BYTES { + return Err(ForgeError::InvalidArgument(format!( + "kv value exceeds {MAX_VALUE_BYTES} byte limit (got {})", + value.len() + ))); + } + Ok(()) } /// Get a value by key. Returns `None` if the key doesn't exist or is expired. pub async fn get(&self, key: &str) -> Result>> { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let row = sqlx::query_scalar!( r#" SELECT value @@ -49,7 +79,8 @@ impl KvStore { /// Set a key to a value. Overwrites any existing value. pub async fn set(&self, key: &str, value: &[u8], ttl: Option) -> Result<()> { - let full_key = self.prefixed_key(key); + Self::check_value_size(value)?; + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); sqlx::query!( r#" @@ -80,10 +111,21 @@ impl KvStore { value: &[u8], ttl: Option, ) -> Result { - let full_key = self.prefixed_key(key); + Self::check_value_size(value)?; + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); - // ON CONFLICT WHERE treats expired rows as absent atomically. - // Convert to query!() after next `cargo sqlx prepare`. + // Serialize concurrent reclaim-of-expired racers per key via a + // transaction-scoped advisory lock keyed on the full prefixed key. + // Without this, the ON CONFLICT WHERE branch is only race-free under + // READ COMMITTED — under REPEATABLE READ a second writer can see the + // pre-update snapshot and "succeed" against an already-claimed row. 
+ let mut tx = self.pool.begin().await.map_err(ForgeError::Database)?; + #[allow(clippy::disallowed_methods)] + sqlx::query("SELECT pg_advisory_xact_lock(hashtext($1)::bigint)") + .bind(&full_key) + .execute(&mut *tx) + .await + .map_err(ForgeError::Database)?; #[allow(clippy::disallowed_methods)] let rows = sqlx::query( r#" @@ -97,17 +139,18 @@ impl KvStore { .bind(&full_key) .bind(value) .bind(expires_at) - .execute(&self.pool) + .execute(&mut *tx) .await .map_err(ForgeError::Database)? .rows_affected(); + tx.commit().await.map_err(ForgeError::Database)?; Ok(rows > 0) } /// Delete a key. Returns `true` if the key existed. pub async fn delete(&self, key: &str) -> Result { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let result = sqlx::query!("DELETE FROM forge_kv WHERE key = $1", full_key) .execute(&self.pool) .await @@ -123,13 +166,17 @@ impl KvStore { /// /// Uses `ON CONFLICT DO UPDATE ... WHERE` to handle expired rows atomically /// without CTE snapshot isolation issues. + /// + /// **Note:** counter storage is `BIGINT`, range `[-2^63, 2^63 - 1]`. + /// Increments that overflow surface as `ForgeError::InvalidArgument` so + /// callers can choose to reset rather than retry indefinitely. pub async fn increment(&self, key: &str, delta: i64, ttl: Option) -> Result { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); // Expired counters reset to delta rather than accumulating. // Convert to query_scalar!() after next `cargo sqlx prepare`. 
#[allow(clippy::disallowed_methods)] - let row: (i64,) = sqlx::query_as( + let row: std::result::Result<(i64,), sqlx::Error> = sqlx::query_as( r#" INSERT INTO forge_kv_counters (key, value, expires_at, updated_at) VALUES ($1, $2, $3, NOW()) @@ -149,10 +196,41 @@ impl KvStore { .bind(delta) .bind(expires_at) .fetch_one(&self.pool) + .await; + + match row { + Ok((v,)) => Ok(v), + Err(sqlx::Error::Database(db_err)) if db_err.code().as_deref() == Some("22003") => { + Err(ForgeError::InvalidArgument(format!( + "counter overflow at key {full_key:?}: BIGINT range exceeded" + ))) + } + Err(e) => Err(ForgeError::Database(e)), + } + } + + /// Read a counter's current value. Returns `None` if missing or expired. + /// + /// Mirrors the TTL filter from `get()` so an expired counter behaves the + /// same as a missing one — keeps a future generic `get` over both tables + /// consistent. + pub async fn get_counter(&self, key: &str) -> Result> { + let full_key = self.prefixed_key(key)?; + #[allow(clippy::disallowed_methods)] + let row: Option<(i64,)> = sqlx::query_as( + r#" + SELECT value + FROM forge_kv_counters + WHERE key = $1 + AND (expires_at IS NULL OR expires_at > NOW()) + "#, + ) + .bind(&full_key) + .fetch_optional(&self.pool) .await .map_err(ForgeError::Database)?; - Ok(row.0) + Ok(row.map(|(v,)| v)) } /// Remove expired keys from both tables. Returns total rows cleaned up. @@ -190,8 +268,11 @@ mod tests { .expect("connect_lazy never fails for a syntactically valid URL"); let store = KvStore::new(pool, "ratelimit"); - assert_eq!(store.prefixed_key("user:42"), "ratelimit:user:42"); - assert_eq!(store.prefixed_key(""), "ratelimit:"); + // `:` in a key is now rejected — verify that and a couple of + // representative happy cases. 
+ assert!(store.prefixed_key("user:42").is_err()); + assert_eq!(store.prefixed_key("user_42").unwrap(), "ratelimit:user_42"); + assert_eq!(store.prefixed_key("").unwrap(), "ratelimit:"); } #[tokio::test] @@ -205,7 +286,10 @@ mod tests { // physical keys — the property the namespace exists to guarantee. let a = KvStore::new(pool.clone(), "subsystem_a"); let b = KvStore::new(pool, "subsystem_b"); - assert_ne!(a.prefixed_key("shared"), b.prefixed_key("shared")); + assert_ne!( + a.prefixed_key("shared").unwrap(), + b.prefixed_key("shared").unwrap() + ); } } diff --git a/crates/forge-runtime/src/mcp/registry.rs b/crates/forge-runtime/src/mcp/registry.rs index 9a143e10..d72da116 100644 --- a/crates/forge-runtime/src/mcp/registry.rs +++ b/crates/forge-runtime/src/mcp/registry.rs @@ -3,25 +3,10 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use forge_core::util::normalize_handler_args as normalize_args; use forge_core::{ForgeMcpTool, McpToolContext, McpToolInfo, Result}; use serde_json::Value; -fn normalize_args(args: Value) -> Value { - let unwrapped = match args { - Value::Object(map) if map.len() == 1 => map - .get("args") - .or_else(|| map.get("input")) - .cloned() - .unwrap_or(Value::Object(map)), - other => other, - }; - - match unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - other => other, - } -} - pub type BoxedMcpToolFn = Arc< dyn Fn(&McpToolContext, Value) -> Pin> + Send + '_>> + Send diff --git a/crates/forge-runtime/src/observability/db.rs b/crates/forge-runtime/src/observability/db.rs index ec6b5e71..c3f419a1 100644 --- a/crates/forge-runtime/src/observability/db.rs +++ b/crates/forge-runtime/src/observability/db.rs @@ -3,7 +3,7 @@ use opentelemetry::metrics::{Gauge, Histogram}; use sqlx::PgPool; use std::sync::OnceLock; use std::time::{Duration, Instant}; -use tracing::{Instrument, info_span}; +use tracing::{Instrument, Level, debug_span, enabled, info_span}; const DB_SYSTEM: &str = "db.system"; const 
DB_OPERATION_NAME: &str = "db.operation.name"; @@ -90,43 +90,78 @@ pub fn record_query_duration(operation: &str, duration: Duration) { } /// Extract the table name from a simple SQL query, or `None` for complex ones. +/// +/// Walks the source by `char_indices` rather than fixed byte offsets so +/// non-ASCII identifiers (quoted Unicode columns/tables) can't panic the +/// slicer. `to_uppercase()` can change the byte length of a string, so we +/// can't reuse byte offsets discovered in the uppercased copy against the +/// original — locate keywords case-insensitively over the original instead. pub fn extract_table_name(sql: &str) -> Option<&str> { let sql = sql.trim(); - let upper = sql.to_uppercase(); - if upper.starts_with("SELECT") { - // SELECT ... FROM table_name ... - if let Some(from_pos) = upper.find(" FROM ") { - let after_from = &sql[from_pos + 6..]; - return extract_first_identifier(after_from.trim_start()); + if let Some(rest) = strip_keyword_prefix(sql, "INSERT INTO ") + .or_else(|| strip_keyword_prefix(sql, "DELETE FROM ")) + .or_else(|| strip_keyword_prefix(sql, "CREATE TABLE IF NOT EXISTS ")) + .or_else(|| strip_keyword_prefix(sql, "CREATE TABLE ")) + .or_else(|| strip_keyword_prefix(sql, "UPDATE ")) + { + return extract_first_identifier(rest.trim_start()); + } + + if strip_keyword_prefix(sql, "SELECT").is_some() { + // Find " FROM " case-insensitively without re-allocating a full + // uppercase copy whose byte length can diverge from the source. 
+ if let Some(from_byte) = find_ci(sql, " FROM ") { + let after = sql.get(from_byte + " FROM ".len()..)?; + return extract_first_identifier(after.trim_start()); } - } else if upper.starts_with("INSERT INTO ") { - let after_into = &sql[12..]; - return extract_first_identifier(after_into.trim_start()); - } else if upper.starts_with("UPDATE ") { - let after_update = &sql[7..]; - return extract_first_identifier(after_update.trim_start()); - } else if upper.starts_with("DELETE FROM ") { - let after_from = &sql[12..]; - return extract_first_identifier(after_from.trim_start()); - } else if upper.starts_with("CREATE TABLE ") { - let after_table = if upper.starts_with("CREATE TABLE IF NOT EXISTS ") { - &sql[27..] - } else { - &sql[13..] - }; - return extract_first_identifier(after_table.trim_start()); } None } +fn strip_keyword_prefix<'a>(sql: &'a str, keyword: &str) -> Option<&'a str> { + if sql.len() < keyword.len() { + return None; + } + let prefix = sql.get(..keyword.len())?; + if prefix.eq_ignore_ascii_case(keyword) { + sql.get(keyword.len()..) + } else { + None + } +} + +/// Case-insensitive search for an ASCII needle. Returns the byte offset of +/// the first match in the source. +fn find_ci(haystack: &str, needle_ascii_upper: &str) -> Option { + let bytes = haystack.as_bytes(); + let n = needle_ascii_upper.as_bytes(); + if n.is_empty() || bytes.len() < n.len() { + return None; + } + 'outer: for start in 0..=bytes.len() - n.len() { + for (i, nb) in n.iter().enumerate() { + let hb = bytes.get(start + i)?; + if !hb.eq_ignore_ascii_case(nb) { + continue 'outer; + } + } + // Confirm the match begins on a UTF-8 char boundary so the caller's + // slice never bisects a multi-byte sequence. 
+ if haystack.is_char_boundary(start) && haystack.is_char_boundary(start + n.len()) { + return Some(start); + } + } + None +} + fn extract_first_identifier(s: &str) -> Option<&str> { let end = s .find(|c: char| c.is_whitespace() || c == '(' || c == ',' || c == ';') .unwrap_or(s.len()); - if end > 0 { Some(&s[..end]) } else { None } + if end > 0 { s.get(..end) } else { None } } /// Execute a database operation with tracing and duration recording. @@ -134,7 +169,11 @@ pub async fn instrumented_query(operation: &str, table: Option<&str>, f where F: std::future::Future>, { - let span = if let Some(tbl) = table { + // Skip span allocation entirely when DEBUG isn't enabled — saves the + // ~few-hundred-ns alloc per query when the operator runs at warn/info. + let span = if !enabled!(Level::DEBUG) { + debug_span!("db.query") + } else if let Some(tbl) = table { info_span!( "db.query", db.system = DB_SYSTEM_POSTGRESQL, diff --git a/crates/forge-runtime/src/observability/metrics.rs b/crates/forge-runtime/src/observability/metrics.rs index 42632477..e386a3b0 100644 --- a/crates/forge-runtime/src/observability/metrics.rs +++ b/crates/forge-runtime/src/observability/metrics.rs @@ -2,7 +2,8 @@ use opentelemetry::{ KeyValue, global, metrics::{Counter, Gauge, Histogram, UpDownCounter}, }; -use std::sync::OnceLock; +use std::collections::HashSet; +use std::sync::{OnceLock, RwLock}; const METER_NAME: &str = "forge-runtime"; @@ -43,9 +44,10 @@ impl HttpMetrics { } pub fn record(&self, method: &str, path: &str, status: u16, duration_secs: f64) { + let normalized = normalize_path_for_metrics(path); let attributes = [ KeyValue::new("method", method.to_string()), - KeyValue::new("path", path.to_string()), + KeyValue::new("path", normalized), KeyValue::new("status", i64::from(status)), ]; @@ -54,6 +56,69 @@ impl HttpMetrics { } } +/// Soft cap on the number of distinct dynamic label values we'll let through +/// before collapsing the rest to ``. 
Keeps the OTel cardinality +/// bounded even if a poorly-controlled handler name leaks into the metric. +const MAX_DYNAMIC_LABELS: usize = 1000; + +static FN_LABELS_SEEN: OnceLock>> = OnceLock::new(); +static JOB_LABELS_SEEN: OnceLock>> = OnceLock::new(); + +fn capped_label(seen: &OnceLock>>, name: &str) -> String { + let set = seen.get_or_init(|| RwLock::new(HashSet::new())); + if let Ok(guard) = set.read() + && guard.contains(name) + { + return name.to_string(); + } + if let Ok(mut guard) = set.write() { + if guard.contains(name) { + return name.to_string(); + } + if guard.len() >= MAX_DYNAMIC_LABELS { + return "".to_string(); + } + guard.insert(name.to_string()); + return name.to_string(); + } + name.to_string() +} + +/// Map a concrete request path to a stable route template. Without route +/// info from the router, we apply heuristics: replace UUIDs and all-numeric +/// path segments with `:id`. Anything we don't recognize stays as-is so +/// known fixed routes (e.g. `/_api/health`) keep their identity. +pub fn normalize_path_for_metrics(path: &str) -> String { + fn is_dynamic(seg: &str) -> bool { + !seg.is_empty() && (looks_like_uuid(seg) || seg.chars().all(|c| c.is_ascii_digit())) + } + // The common case is a fixed route with nothing to rewrite — return it in a + // single allocation rather than rebuilding it segment by segment. 
+ if !path.split('/').any(is_dynamic) { + return path.to_string(); + } + let mut out = String::with_capacity(path.len()); + for (i, seg) in path.split('/').enumerate() { + if i > 0 { + out.push('/'); + } + if is_dynamic(seg) { + out.push_str(":id"); + } else { + out.push_str(seg); + } + } + out +} + +fn looks_like_uuid(s: &str) -> bool { + s.len() == 36 + && s.as_bytes().iter().enumerate().all(|(i, b)| match i { + 8 | 13 | 18 | 23 => *b == b'-', + _ => b.is_ascii_hexdigit(), + }) +} + pub struct FnMetrics { executions_total: Counter, duration: Histogram, @@ -91,7 +156,7 @@ impl FnMetrics { ) { let status = if success { "ok" } else { "error" }; let attributes = [ - KeyValue::new("function", function.to_string()), + KeyValue::new("function", capped_label(&FN_LABELS_SEEN, function)), KeyValue::new("kind", kind.to_string()), KeyValue::new("status", status), KeyValue::new("cached", cached), @@ -130,7 +195,10 @@ impl FnCacheMetrics { } pub fn record(&self, function: &str, hit: bool) { - let attributes = [KeyValue::new("function", function.to_string())]; + let attributes = [KeyValue::new( + "function", + capped_label(&FN_LABELS_SEEN, function), + )]; if hit { self.hits_total.add(1, &attributes); } else { @@ -181,7 +249,7 @@ impl JobMetrics { pub fn record(&self, job_type: &str, status: &'static str, duration_secs: f64) { let attributes = [ - KeyValue::new("job_type", job_type.to_string()), + KeyValue::new("job_type", capped_label(&JOB_LABELS_SEEN, job_type)), KeyValue::new("status", status), ]; @@ -190,8 +258,13 @@ impl JobMetrics { } pub fn record_lost_claim(&self, job_type: &str) { - self.lost_claim_total - .add(1, &[KeyValue::new("job_type", job_type.to_string())]); + self.lost_claim_total.add( + 1, + &[KeyValue::new( + "job_type", + capped_label(&JOB_LABELS_SEEN, job_type), + )], + ); } } @@ -425,32 +498,67 @@ mod tests { use super::*; #[test] - fn test_http_metrics_creation() { - let _metrics = HttpMetrics::new(); + fn normalize_path_rewrites_uuid_segments() { + 
assert_eq!( + normalize_path_for_metrics("/_api/rpc/get_user/550e8400-e29b-41d4-a716-446655440000"), + "/_api/rpc/get_user/:id" + ); } #[test] - fn test_job_metrics_creation() { - let _metrics = JobMetrics::new(); + fn normalize_path_rewrites_all_digit_segments() { + assert_eq!(normalize_path_for_metrics("/users/12345"), "/users/:id"); + assert_eq!(normalize_path_for_metrics("/a/1/b/2"), "/a/:id/b/:id"); } #[test] - fn test_connections_gauge_creation() { - let _gauge = ActiveConnectionsGauge::new(); + fn normalize_path_leaves_fixed_routes_untouched() { + // No dynamic segment => returned verbatim (single-allocation fast path). + for fixed in ["/_api/health", "/_api/ready", "/_api/rpc/list_users", "/"] { + assert_eq!(normalize_path_for_metrics(fixed), fixed); + } + } + + #[test] + fn normalize_path_preserves_trailing_slash_shape() { + // A trailing slash makes an empty final segment, which is_dynamic + // treats as non-dynamic, so it is preserved. + assert_eq!(normalize_path_for_metrics("/users/12345/"), "/users/:id/"); + assert_eq!(normalize_path_for_metrics("/_api/health/"), "/_api/health/"); } #[test] - fn test_notify_metrics_creation() { - let _metrics = NotifyMetrics::new(); + fn normalize_path_does_not_rewrite_alphanumeric_or_short_hex() { + // 32-char hex (not a 36-char dashed UUID) and mixed alnum stay as-is. + assert_eq!(normalize_path_for_metrics("/items/abc123"), "/items/abc123"); + assert_eq!( + normalize_path_for_metrics("/x/0123456789abcdef0123456789abcdef"), + "/x/0123456789abcdef0123456789abcdef" + ); } #[test] - fn test_subscription_metrics_creation() { - let _metrics = SubscriptionMetrics::new(); + fn capped_label_passes_known_names_through_until_cap() { + // Use a fresh, test-local OnceLock so the global label sets aren't + // perturbed and the cap is exercised deterministically. 
+ let seen: OnceLock>> = OnceLock::new(); + for i in 0..MAX_DYNAMIC_LABELS { + let name = format!("fn_{i}"); + assert_eq!(capped_label(&seen, &name), name); + } + // The set is now full; a brand-new name collapses to the other-bucket. + assert_eq!(capped_label(&seen, "fn_overflow"), ""); + // An already-seen name still passes through after the cap is hit. + assert_eq!(capped_label(&seen, "fn_0"), "fn_0"); } #[test] - fn test_workflow_scheduler_metrics_creation() { - let _metrics = WorkflowSchedulerMetrics::new(); + fn capped_label_is_idempotent_for_repeated_names() { + let seen: OnceLock>> = OnceLock::new(); + assert_eq!(capped_label(&seen, "get_user"), "get_user"); + assert_eq!(capped_label(&seen, "get_user"), "get_user"); + // Only one slot consumed for the repeated name. + let guard = seen.get().expect("init").read().expect("read lock"); + assert_eq!(guard.len(), 1); } } diff --git a/crates/forge-runtime/src/observability/telemetry.rs b/crates/forge-runtime/src/observability/telemetry.rs index 3a00269a..796b7e53 100644 --- a/crates/forge-runtime/src/observability/telemetry.rs +++ b/crates/forge-runtime/src/observability/telemetry.rs @@ -273,7 +273,8 @@ pub fn init_telemetry( ) } Err(e) => { - eprintln!("WARNING: OTLP trace exporter init failed, traces disabled: {e}"); + tracing::error!(error = %e, "OTLP trace exporter init failed; traces disabled"); + record_otel_export_initialized("traces", false); None } } @@ -306,7 +307,8 @@ pub fn init_telemetry( Some(log_layer) } Err(e) => { - eprintln!("WARNING: OTLP log exporter init failed, log bridge disabled: {e}"); + tracing::error!(error = %e, "OTLP log exporter init failed; log bridge disabled"); + record_otel_export_initialized("logs", false); None } } @@ -335,11 +337,22 @@ pub fn init_telemetry( global::set_meter_provider(meter_provider); } Err(e) => { - eprintln!("WARNING: OTLP metric exporter init failed, metrics disabled: {e}"); + tracing::error!(error = %e, "OTLP metric exporter init failed; metrics 
disabled"); + record_otel_export_initialized("metrics", false); } } } + if config.enable_traces { + record_otel_export_initialized("traces", TRACER_PROVIDER.get().is_some()); + } + if config.enable_logs { + record_otel_export_initialized("logs", LOGGER_PROVIDER.get().is_some()); + } + if config.enable_metrics { + record_otel_export_initialized("metrics", METER_PROVIDER.get().is_some()); + } + tracing::info!( service = %config.service_name, version = %config.service_version, @@ -353,6 +366,26 @@ pub fn init_telemetry( Ok(true) } +/// Health gauge that flips to 1 when an OTLP exporter initialized +/// successfully and 0 when it failed at startup. Lets operators alert on +/// missing telemetry without parsing log lines. +fn record_otel_export_initialized(signal: &'static str, initialized: bool) { + use opentelemetry::metrics::Gauge; + static GAUGE: OnceLock> = OnceLock::new(); + let gauge = GAUGE.get_or_init(|| { + global::meter("forge-runtime") + .u64_gauge("otel_export_initialized") + .with_description( + "1 if the OTLP exporter for this signal initialized at startup, 0 if it failed.", + ) + .build() + }); + gauge.record( + if initialized { 1 } else { 0 }, + &[KeyValue::new("signal", signal)], + ); +} + pub fn shutdown_telemetry() { tracing::info!("shutting down telemetry"); diff --git a/crates/forge-runtime/src/pg/change_log.rs b/crates/forge-runtime/src/pg/change_log.rs index 633d3eed..78ad59cf 100644 --- a/crates/forge-runtime/src/pg/change_log.rs +++ b/crates/forge-runtime/src/pg/change_log.rs @@ -176,9 +176,18 @@ mod integration_tests { let base = TestDatabase::from_env() .await .expect("Failed to create test database"); - base.isolated(test_name) + let db = base + .isolated(test_name) .await - .expect("Failed to create isolated db") + .expect("Failed to create isolated db"); + // forge_change_log and forge_notify_change() live in the system schema; + // an isolated DB starts empty, so apply it (mirrors queue.rs setup_db). 
+ // Without this the tests fail with "relation forge_change_log does not + // exist" — which went unnoticed because this suite never ran in CI. + db.run_sql(&crate::pg::migration::get_all_system_sql()) + .await + .expect("Failed to apply system schema"); + db } /// Create a tracked table that fires the change trigger on every write. @@ -234,23 +243,47 @@ mod integration_tests { #[tokio::test] async fn trim_deletes_only_rows_older_than_cutoff() { let db = setup_db("change_log_trim").await; - install_tracked_table(db.pool(), "trim_items").await; - sqlx::query("INSERT INTO trim_items (id, name) VALUES (gen_random_uuid(), 'old')") - .execute(db.pool()) - .await - .unwrap(); - // Forge a future cutoff so the row qualifies as "old". - let cutoff = Utc::now() + chrono::Duration::seconds(10); - let deleted = trim_change_log(db.pool(), cutoff).await.unwrap(); - assert_eq!(deleted, 1); - let remaining: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM forge_change_log WHERE table_name='trim_items'", + // trim_change_log enforces a retention floor: it is a no-op until the log + // exceeds CHANGE_LOG_MIN_ROWS, so it never over-trims a small log. Seed + // just past the floor with OLD rows plus a handful of recent ones, then + // trim at a cutoff between them — only the old rows may be deleted. + // Insert directly (the trigger path is covered by the drain test) so the + // created_at timestamps are controllable. 
+ let old_count = CHANGE_LOG_MIN_ROWS + 50; + sqlx::query( + "INSERT INTO forge_change_log (table_name, op, created_at) + SELECT 'trim_items', 'INSERT', NOW() - INTERVAL '2 days' + FROM generate_series(1, $1)", ) - .fetch_one(db.pool()) + .bind(old_count) + .execute(db.pool()) .await .unwrap(); - assert_eq!(remaining, 0); + sqlx::query( + "INSERT INTO forge_change_log (table_name, op, created_at) + SELECT 'trim_items', 'INSERT', NOW() + FROM generate_series(1, 5)", + ) + .execute(db.pool()) + .await + .unwrap(); + + let cutoff = Utc::now() - chrono::Duration::days(1); + let deleted = trim_change_log(db.pool(), cutoff).await.unwrap(); + assert_eq!( + deleted, old_count as u64, + "every row older than the cutoff must be trimmed once past the floor", + ); + + let remaining: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM forge_change_log") + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!( + remaining, 5, + "rows newer than the cutoff must survive the trim" + ); } #[tokio::test] diff --git a/crates/forge-runtime/src/pg/leader.rs b/crates/forge-runtime/src/pg/leader.rs index 875d0f07..f1e3a5d6 100644 --- a/crates/forge-runtime/src/pg/leader.rs +++ b/crates/forge-runtime/src/pg/leader.rs @@ -12,6 +12,14 @@ use crate::pg::notify_bus::PgNotifyBus; /// Payload is the role string; subscribers filter by their own role. pub const LEADER_RELEASED_CHANNEL: &str = "forge_leader_released"; +/// Number of times to retry `pg_try_advisory_lock` after terminating a zombie +/// leader's backend. PostgreSQL releases the dead backend's advisory locks +/// asynchronously, so the first retry can race the teardown. +const PREEMPT_RETRY_ATTEMPTS: u32 = 10; + +/// Backoff between post-termination lock-acquisition retries. +const PREEMPT_RETRY_BACKOFF: Duration = Duration::from_millis(25); + /// Leader election configuration. 
#[derive(Debug, Clone)] pub struct LeaderConfig { @@ -132,6 +140,19 @@ impl LeaderElection { self.is_leader.load(Ordering::SeqCst) } + /// Subscribe to leader-released NOTIFY events for this role, if a notify + /// bus is attached. Returns `None` when no bus is configured (single-node + /// or test setups), in which case callers should fall back to polling. + /// + /// Standby polling loops use this to wake immediately when the current + /// leader voluntarily releases, instead of sleeping for the full + /// `check_interval`. + pub fn subscribe_release_notify(&self) -> Option> { + self.notify_bus + .as_ref() + .and_then(|bus| bus.subscribe(LEADER_RELEASED_CHANNEL)) + } + /// How often the leader validates the advisory lock is still held. pub fn lock_validate_interval(&self) -> Duration { self.config.lock_validate_interval @@ -307,6 +328,38 @@ impl LeaderElection { } }; + // Verify the lock-holding backend belongs to a forge process before + // terminating it: two unrelated apps sharing this DB can hash to the + // same advisory lock ID, and we must not evict the other app's session. + // Connections without an application_name are assumed non-forge and + // skipped — operators can set `application_name=forge-` to opt in. + // Untyped query: `pg_stat_activity` rows can come and go between + // statements (the holder may have exited), so the macro's static row + // shape buys nothing here. Allow the lint locally. + #[allow(clippy::disallowed_methods)] + let app_name: Option = sqlx::query_scalar::<_, Option>( + "SELECT application_name FROM pg_stat_activity WHERE pid = $1", + ) + .bind(pid) + .fetch_optional(&mut **conn) + .await + .map_err(forge_core::ForgeError::Database)? 
+ .flatten(); + + match app_name.as_deref() { + Some(name) if name.starts_with("forge") => {} + other => { + tracing::warn!( + role = self.role.as_str(), + zombie_pid = pid, + application_name = ?other, + "Refusing to terminate backend whose application_name does not start with 'forge'; \ + another app may share this database. Set application_name=forge- to allow preemption." + ); + return Ok(false); + } + } + // pg_terminate_backend returns false when permission is denied or the backend is already gone. let terminated = sqlx::query_scalar!(r#"SELECT pg_terminate_backend($1) AS "terminated!""#, pid,) @@ -332,18 +385,31 @@ impl LeaderElection { "Terminated zombie leader backend with expired lease; retrying lock acquisition" ); - // Yield to let PG process the termination before retrying the lock. - tokio::task::yield_now().await; + // pg_terminate_backend only *signals* the backend; PostgreSQL releases + // its advisory locks asynchronously as that backend tears down. A single + // immediate retry races that teardown and usually loses, so we poll + // pg_try_advisory_lock a few times with a short backoff. The window is + // small (PG processes the signal in milliseconds) but a bare yield is + // not enough. + for attempt in 0..PREEMPT_RETRY_ATTEMPTS { + let acquired = sqlx::query_scalar!( + r#"SELECT pg_try_advisory_lock($1) AS "acquired!""#, + self.role.lock_id(), + ) + .fetch_one(&mut **conn) + .await + .map_err(forge_core::ForgeError::Database)?; - let acquired = sqlx::query_scalar!( - r#"SELECT pg_try_advisory_lock($1) AS "acquired!""#, - self.role.lock_id(), - ) - .fetch_one(&mut **conn) - .await - .map_err(forge_core::ForgeError::Database)?; + if acquired { + return Ok(true); + } - Ok(acquired) + if attempt + 1 < PREEMPT_RETRY_ATTEMPTS { + tokio::time::sleep(PREEMPT_RETRY_BACKOFF).await; + } + } + + Ok(false) } /// Confirm the advisory lock is still held on the lock-owning connection. 
@@ -567,24 +633,9 @@ impl LeaderElection { // resolved by the lock being gone). let mut lock_connection = self.lock_connection.lock().await; if let Some(mut conn) = lock_connection.take() { - // Emit NOTIFY before unlock so standbys wake only when the lock is - // genuinely about to be free. Failure is non-fatal: standbys fall - // back to their normal check_interval timer. - if let Err(e) = sqlx::query!( - "SELECT pg_notify($1, $2)", - LEADER_RELEASED_CHANNEL, - self.role.as_str(), - ) - .execute(&mut *conn) - .await - { - tracing::warn!( - role = self.role.as_str(), - error = %e, - "Failed to emit leader-released NOTIFY; standbys will wait for next check tick", - ); - } - + // Unlock first, then NOTIFY only if we actually held the lock. + // Notifying when the lock wasn't held wakes standbys to race for + // a slot we never owned in the first place — pure noise. let released = sqlx::query_scalar!( "SELECT pg_advisory_unlock($1) as \"released!\"", self.role.lock_id() @@ -593,11 +644,26 @@ impl LeaderElection { .await .map_err(forge_core::ForgeError::Database)?; - if !released { + if released { + if let Err(e) = sqlx::query!( + "SELECT pg_notify($1, $2)", + LEADER_RELEASED_CHANNEL, + self.role.as_str(), + ) + .execute(&mut *conn) + .await + { + tracing::warn!( + role = self.role.as_str(), + error = %e, + "Failed to emit leader-released NOTIFY; standbys will wait for next check tick", + ); + } + } else { tracing::warn!( role = self.role.as_str(), "pg_advisory_unlock returned false during release; \ - lock was not held by this session" + lock was not held by this session; skipping NOTIFY" ); } @@ -1082,6 +1148,20 @@ mod integration_tests { assert!(zombie.try_become_leader().await.unwrap()); assert!(zombie.is_leader()); + // Tag the zombie's lock-holding backend with a forge-prefixed + // application_name, matching what the connection pool now sets in + // production (`forge-`). 
The preemption guard only terminates + // backends whose application_name starts with `forge`; without this the + // simulated zombie reports an empty name and would never be evicted. + { + let mut conn_guard = zombie.lock_connection.lock().await; + let conn = conn_guard.as_mut().expect("lock connection present"); + sqlx::query("SET application_name = 'forge-demo'") + .execute(&mut **conn) + .await + .unwrap(); + } + // Artificially expire the lease so standbys see a stale leader. #[allow(clippy::disallowed_methods)] sqlx::query( diff --git a/crates/forge-runtime/src/pg/migration/runner.rs b/crates/forge-runtime/src/pg/migration/runner.rs index 296a4875..736dd12b 100644 --- a/crates/forge-runtime/src/pg/migration/runner.rs +++ b/crates/forge-runtime/src/pg/migration/runner.rs @@ -322,11 +322,15 @@ impl MigrationRunner { &self, conn: &mut sqlx::pool::PoolConnection, ) -> Result<()> { - sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID) - .fetch_one(&mut **conn) - .await - .map_err(|e| ForgeError::internal_with("Failed to release migration lock", e))?; - debug!("Migration lock released"); + let released: Option = + sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID) + .fetch_one(&mut **conn) + .await + .map_err(|e| ForgeError::internal_with("Failed to release migration lock", e))?; + // `false` means this session didn't hold the lock — useful diagnostic + // for connection-pooler scenarios where the lock-holding backend was + // reused before release. Log it instead of silently dropping. + debug!(released = ?released, "Migration lock released"); Ok(()) } @@ -444,15 +448,24 @@ impl MigrationRunner { /// `ALTER TYPE ... ADD VALUE`, `VACUUM`, and `REINDEX CONCURRENTLY` /// inside a transaction block, so opt-in migrations skip the BEGIN. /// - /// Tradeoffs the migration author must accept: - /// - A partial failure leaves the schema half-applied and the - /// bookkeeping row missing, so the next run will retry from the top. 
- /// - Even if all DDL succeeds, the bookkeeping `INSERT` runs on a - /// *fresh* pool connection — if that insert fails, the migration is - /// re-run on the next startup despite already having taken effect. + /// Inherent risk window: DDL commits as each statement runs, but the + /// bookkeeping `INSERT` into `forge_system_migrations` is a separate + /// statement. If the process or the connection dies between the last + /// DDL statement and the INSERT, the schema is migrated but no row is + /// recorded, and the next boot will try to re-apply the migration. + /// + /// To shrink (but not close) that window, the DDL and the bookkeeping + /// INSERT run on the **same** pooled connection — we never hand the + /// connection back to the pool between them, so a healthy connection + /// stays healthy across both steps. A mid-run crash or network drop + /// is still possible; this is the price of skipping the transaction. /// - /// Migrations using this mode must be authored idempotently - /// (`IF NOT EXISTS`, `ADD VALUE IF NOT EXISTS`, and so on). + /// Migrations using this mode **must** be authored idempotently + /// (`CREATE INDEX CONCURRENTLY IF NOT EXISTS`, `ADD VALUE IF NOT + /// EXISTS`, and so on) so a retry on the next boot is a no-op against + /// already-applied schema. A future improvement could detect "schema + /// applied but row missing" at boot and back-fill the bookkeeping row + /// instead of re-running the SQL. async fn apply_non_transactional(&self, migration: &Migration) -> Result<()> { info!( "Applying non-transactional migration: {}", @@ -502,6 +515,30 @@ impl MigrationRunner { } .await; + // Record bookkeeping on the SAME connection used for the DDL. A + // fresh pool connection here would widen the failure window — if + // the new acquire failed, the schema would already be migrated + // with no row to prove it. 
+ let record_result: Result<()> = if exec_result.is_ok() { + let checksum = crate::stable_hash::sha256_hex(migration.up_sql.as_bytes()); + sqlx::query!( + "INSERT INTO forge_system_migrations (version, checksum) VALUES ($1, $2)", + migration.version, + checksum, + ) + .execute(&mut *conn) + .await + .map(|_| ()) + .map_err(|e| { + ForgeError::internal_with( + format!("Failed to record migration '{}'", migration.version), + e, + ) + }) + } else { + Ok(()) + }; + // Always reset before returning the connection to the pool — even on // failure. A failed RESET is rare but operators need visibility into it. if let Err(e) = sqlx::query("RESET lock_timeout").execute(&mut *conn).await { @@ -516,21 +553,7 @@ impl MigrationRunner { drop(conn); exec_result?; - - let checksum = crate::stable_hash::sha256_hex(migration.up_sql.as_bytes()); - sqlx::query!( - "INSERT INTO forge_system_migrations (version, checksum) VALUES ($1, $2)", - migration.version, - checksum, - ) - .execute(&self.pool) - .await - .map_err(|e| { - ForgeError::internal_with( - format!("Failed to record migration '{}'", migration.version), - e, - ) - })?; + record_result?; info!( "Non-transactional migration applied: {} ({:?})", @@ -1286,10 +1309,20 @@ mod integration_tests { assert!(concurrent.transactional); let err = runner.run(vec![setup, concurrent]).await.unwrap_err(); - let msg = err.to_string(); + // The wrapping ForgeError's Display shows only its context ("Failed to + // apply migration ..."); PG's actual reason ("cannot run inside a + // transaction block") lives in the source chain. Walk it so we assert on + // the real rejection — and prove the runner carries the cause, not drops it. 
+ let mut chain = err.to_string(); + let mut source = std::error::Error::source(&err); + while let Some(cause) = source { + chain.push_str(": "); + chain.push_str(&cause.to_string()); + source = cause.source(); + } assert!( - msg.contains("CONCURRENTLY") || msg.to_lowercase().contains("transaction"), - "expected PG to reject concurrent index in tx, got: {msg}" + chain.contains("CONCURRENTLY") || chain.to_lowercase().contains("transaction"), + "expected PG to reject concurrent index in tx, got chain: {chain}" ); } diff --git a/crates/forge-runtime/src/pg/mod.rs b/crates/forge-runtime/src/pg/mod.rs index cc0bcfe9..71dda509 100644 --- a/crates/forge-runtime/src/pg/mod.rs +++ b/crates/forge-runtime/src/pg/mod.rs @@ -13,6 +13,6 @@ pub use migration::{ AppliedMigration, DriftStatus, Migration, MigrationRunner, MigrationStatus, load_migrations_from_dir, }; -pub use notify::{MAX_PAYLOAD_BYTES, NotifyChannel}; +pub use notify::{MAX_PAYLOAD_BYTES, NotifyChannel, NotifyStreamError}; pub use notify_bus::PgNotifyBus; pub use pool::Database; diff --git a/crates/forge-runtime/src/pg/notify.rs b/crates/forge-runtime/src/pg/notify.rs index 777f51ae..132b51a7 100644 --- a/crates/forge-runtime/src/pg/notify.rs +++ b/crates/forge-runtime/src/pg/notify.rs @@ -84,6 +84,13 @@ where /// - `ForgeError::Serialization` if `serde_json::to_string(payload)` fails. /// - `ForgeError::InvalidArgument` if the serialized payload exceeds /// [`MAX_PAYLOAD_BYTES`]. Use the change-log fallback for larger bodies. + /// + /// **Note**: this cap only applies to publishers that route through this + /// method. Server-side triggers that build their own payloads in PL/pgSQL + /// bypass the check entirely. Exceeding the 8 KiB PostgreSQL limit there + /// aborts the trigger's wrapping transaction (typically the user mutation + /// that caused the trigger to fire). Trigger authors must enforce their + /// own bounds. 
/// - `ForgeError::Database` if the underlying `SELECT pg_notify(...)` /// fails (transaction rolled back, connection dropped, etc.). pub async fn publish<'e, E>(&self, executor: E, payload: &T) -> Result<()> @@ -110,6 +117,29 @@ where } } +/// Reason a [`NotifyChannel`] subscription terminated mid-stream. +/// +/// Previously the stream simply ended via `take_while(is_ok)`, leaving the +/// caller with no way to distinguish a deliberate close from a PG-side error. +/// Items now carry `Result` so consumers can decide +/// whether to reconnect, surface the error, or treat it as fatal. +#[derive(Debug)] +pub enum NotifyStreamError { + /// The underlying `PgListener::recv` returned an error. Typically a + /// dropped backend connection — callers should reconnect. + Recv(sqlx::Error), +} + +impl std::fmt::Display for NotifyStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Recv(e) => write!(f, "PgListener recv failed: {e}"), + } + } +} + +impl std::error::Error for NotifyStreamError {} + impl NotifyChannel where T: DeserializeOwned + Send + 'static, @@ -118,38 +148,53 @@ where /// /// `listener` is consumed; the caller surrenders the connection to the /// stream for the duration of the subscription. Notifications whose - /// payload fails JSON decoding are logged and skipped, so a malformed - /// publish from one peer cannot tear down a long-running subscriber. - /// Errors from the underlying `recv` (connection dropped, etc.) end the - /// stream; the caller decides whether to reconnect. - pub async fn subscribe(&self, mut listener: PgListener) -> Result> { + /// payload fails JSON decoding are logged and skipped (a malformed + /// publish from one peer cannot tear down a long-running subscriber). + /// + /// Recv errors (connection dropped, etc.) 
are surfaced as + /// `Err(NotifyStreamError::Recv)` so the caller can distinguish a + /// graceful close (stream ends naturally) from a fault that requires + /// reconnect. After yielding an error the stream terminates. + pub async fn subscribe( + &self, + mut listener: PgListener, + ) -> Result>> { listener .listen(self.name) .await .map_err(ForgeError::Database)?; let channel_name = self.name; let raw = listener.into_stream(); + // Pass recv errors through (mapped to NotifyStreamError) and drop + // malformed payloads silently — the latter would otherwise look like + // a fault to subscribers when it's just a bad publish. let stream = raw - .take_while(|res| { - let cont = res.is_ok(); - async move { cont } - }) - .filter_map(move |res| async move { - let notification = match res { - Ok(n) => n, - Err(_) => return None, - }; - match serde_json::from_str::(notification.payload()) { - Ok(value) => Some(value), + .scan(false, |ended, res| { + let done = *ended; + let next = match res { + Ok(n) => Some(Ok(n)), Err(e) => { - tracing::debug!( - channel = channel_name, - error = %e, - payload = notification.payload(), - "NotifyChannel: dropping malformed payload", - ); - None + *ended = true; + Some(Err(NotifyStreamError::Recv(e))) } + }; + async move { if done { None } else { next } } + }) + .filter_map(move |res| async move { + match res { + Err(e) => Some(Err(e)), + Ok(notification) => match serde_json::from_str::(notification.payload()) { + Ok(value) => Some(Ok(value)), + Err(e) => { + tracing::debug!( + channel = channel_name, + error = %e, + payload = notification.payload(), + "NotifyChannel: dropping malformed payload", + ); + None + } + }, } }); Ok(stream) @@ -239,7 +284,8 @@ mod integration_tests { let received = tokio::time::timeout(Duration::from_secs(5), stream.next()) .await .expect("stream did not yield within 5s") - .expect("stream ended before yielding"); + .expect("stream ended before yielding") + .expect("recv ok"); assert_eq!(received, payload); } @@ 
-285,7 +331,8 @@ mod integration_tests { let received = tokio::time::timeout(Duration::from_secs(5), stream.next()) .await .expect("stream did not yield within 5s") - .expect("stream ended before yielding"); + .expect("stream ended before yielding") + .expect("recv ok"); assert_eq!(received, payload); } } diff --git a/crates/forge-runtime/src/pg/notify_bus.rs b/crates/forge-runtime/src/pg/notify_bus.rs index 55dda8eb..bd617aa7 100644 --- a/crates/forge-runtime/src/pg/notify_bus.rs +++ b/crates/forge-runtime/src/pg/notify_bus.rs @@ -15,6 +15,24 @@ //! old per-subsystem listeners had, except now there is exactly one //! reconnect path to maintain). //! +//! # Replay backing per channel +//! +//! Not every channel survives a reconnect gap equally well. Subscribers must +//! pair the bus with channel-specific recovery: +//! +//! - `forge_changes`: backed by `forge_change_log`. Subscribers replay missed +//! rows by `last_seen_seq` after a `subscribe_reconnects` tick. +//! - `forge_workflow_wakeup`: idempotent — the workflow executor's normal +//! timer poll catches missed wakeups within its tick interval. +//! - `forge_leader_released`: a missed event delays standby acquisition by at +//! most one `LeaderConfig::check_interval`. +//! - `forge_jobs_available`: **no replay backing**. A missed NOTIFY leaves +//! jobs unclaimed until the next worker poll (`poll_interval`, default +//! 5 s). Workers must keep their independent poll cadence even when this +//! channel is connected; do not extend `poll_interval` past acceptable +//! tail latency on the assumption that NOTIFY will always be timely. +//! - `forge_schema_changed`: advisory only; reconnect re-fetches schema. +//! //! # Payload semantics //! //! The bus forwards the raw `notification.payload()` string. Channels that @@ -32,7 +50,12 @@ use tokio::sync::{broadcast, watch}; /// Per-channel broadcast buffer size. 
Subscribers that fall behind by more /// than this many messages will see `RecvError::Lagged` and can decide /// whether to catch up or resync. -const CHANNEL_BUFFER_SIZE: usize = 256; +/// +/// Sized for bursty `forge_changes` workloads where a single transaction can +/// emit hundreds of notifications. Every direct subscriber MUST handle +/// `broadcast::error::RecvError::Lagged` — drop-and-resync is the only +/// correct response since the bus does not back-pressure publishers. +const CHANNEL_BUFFER_SIZE: usize = 4096; /// Initial reconnection delay after a `PgListener` disconnect. const INITIAL_BACKOFF: Duration = Duration::from_millis(500); diff --git a/crates/forge-runtime/src/pg/pool.rs b/crates/forge-runtime/src/pg/pool.rs index ed72f70e..05551ff3 100644 --- a/crates/forge-runtime/src/pg/pool.rs +++ b/crates/forge-runtime/src/pg/pool.rs @@ -173,7 +173,7 @@ impl Database { fn connect_options(url: &str, service_name: &str) -> sqlx::Result { let options: PgConnectOptions = url.parse()?; Ok(options - .application_name(service_name) + .application_name(&forge_application_name(service_name)) .log_statements(LevelFilter::Off) .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500))) } @@ -185,7 +185,7 @@ impl Database { ) -> sqlx::Result { let options: PgConnectOptions = url.parse()?; let mut opts = options - .application_name(service_name) + .application_name(&forge_application_name(service_name)) .log_statements(LevelFilter::Off) .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500)); if statement_timeout_secs > 0 { @@ -342,6 +342,25 @@ impl Database { } } +/// Build the `application_name` reported by every Forge connection. +/// +/// Leader-election zombie preemption ([`crate::pg::leader`]) only terminates a +/// lock-holding backend whose `application_name` starts with `forge`, so it +/// never evicts an unrelated app sharing the database. 
For that guard to ever +/// fire against Forge's *own* zombie, Forge connections must self-identify with +/// that prefix. The service name passed in is the project name (e.g. `demo`), +/// which would otherwise produce a non-matching `application_name`. +/// +/// Idempotent: a service name already starting with `forge` (e.g. the internal +/// `"forge"` default) is returned unchanged so we never produce `forge-forge`. +fn forge_application_name(service_name: &str) -> String { + if service_name.starts_with("forge") { + service_name.to_string() + } else { + format!("forge-{service_name}") + } +} + /// Minimum supported PostgreSQL major version. /// /// Forge v0.9+ uses features (skip-locked semantics with `NOWAIT`, partitioned @@ -466,4 +485,60 @@ mod tests { assert_eq!(cloned.url(), config.url()); assert_eq!(cloned.pool_size, config.pool_size); } + + #[test] + fn forge_application_name_prefixes_project_names() { + // A bare project name must gain the `forge-` prefix so leader-election + // zombie preemption (which only terminates `forge`-prefixed backends) + // can evict Forge's own zombie leader. + assert_eq!(forge_application_name("demo"), "forge-demo"); + assert_eq!(forge_application_name("my-app"), "forge-my-app"); + // Idempotent: names already starting with `forge` are untouched. + assert_eq!(forge_application_name("forge"), "forge"); + assert_eq!(forge_application_name("forge-worker"), "forge-worker"); + } +} + +#[cfg(all(test, feature = "testcontainers"))] +#[allow(clippy::unwrap_used, clippy::disallowed_methods)] +mod integration_tests { + use super::*; + use forge_core::testing::TestDatabase; + + async fn base_db() -> TestDatabase { + TestDatabase::from_env() + .await + .expect("Failed to create test database") + } + + /// A pool built with a production-shaped service name (the project name) + /// must report a `forge`-prefixed `application_name`. 
This is the precise + /// regressor for zombie-leader eviction: the leader-election guard only + /// terminates backends whose `application_name` starts with `forge`, so if + /// Forge's own pools reported the bare project name (`demo`) the framework + /// could never preempt its own zombie. Fails before the fix (reports + /// `demo`), passes after (reports `forge-demo`). + #[tokio::test] + async fn pool_application_name_is_forge_prefixed_for_preemption() { + let base = base_db().await; + let db = Database::from_config_with_service(&DatabaseConfig::new(base.url()), "demo") + .await + .expect("connect with production-shaped service name"); + + let app_name: String = sqlx::query_scalar("SELECT current_setting('application_name')") + .fetch_one(db.primary()) + .await + .unwrap(); + + assert_eq!( + app_name, "forge-demo", + "Forge pools must self-identify as forge- so leader preemption can evict them" + ); + assert!( + app_name.starts_with("forge"), + "application_name must satisfy the leader.rs preemption guard" + ); + + db.close().await; + } } diff --git a/crates/forge-runtime/src/rate_limit/limiter.rs b/crates/forge-runtime/src/rate_limit/limiter.rs index 06611cde..0cb30adc 100644 --- a/crates/forge-runtime/src/rate_limit/limiter.rs +++ b/crates/forge-runtime/src/rate_limit/limiter.rs @@ -75,7 +75,8 @@ impl StrictRateLimiter { // tokens is clamped to >= -1, so retry_after is bounded by // (1 - (-1)) / refill_rate = 2 / refill_rate — proportional to // one refill interval rather than runaway. - let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate); + let base = (1.0 - tokens) / refill_rate; + let retry_after = Duration::from_secs_f64(jittered(base)); Ok(RateLimitResult::denied(remaining, reset_at, retry_after)) } } @@ -179,6 +180,21 @@ impl StrictRateLimiter { } } +/// Apply ±25% jitter to a retry-after value so clients denied in the same +/// tick don't synchronize their retries into a thundering herd. 
+fn jittered(base_secs: f64) -> f64 { + if !base_secs.is_finite() || base_secs <= 0.0 { + return base_secs.max(0.0); + } + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.subsec_nanos()) + .unwrap_or(0); + // Map nanos -> [-0.25, 0.25]. + let frac = ((nanos as f64) / 1_000_000_000.0) * 0.5 - 0.25; + (base_secs * (1.0 + frac)).max(0.0) +} + struct LocalBucket { tokens: f64, max_tokens: f64, @@ -267,8 +283,18 @@ impl HybridRateLimiter { let max_tokens = config.requests as f64; let refill_rate = config.refill_rate(); - if self.local.len() > self.max_local_buckets { - self.cleanup_local(Duration::from_secs(300)); // evict entries idle > 5 min + // Sweep proactively once we cross 75% of the soft cap so a burst of + // unique keys can't grow the map far past `max_local_buckets` before + // the first eviction runs. If everything is still hot we fall back to + // a hard cap that bounds memory at 2× the configured ceiling. + let len = self.local.len(); + if len > self.max_local_buckets * 3 / 4 { + self.cleanup_local(Duration::from_secs(300)); + if self.local.len() > self.max_local_buckets * 2 { + // Last-resort: drop entries idle > 30s to bound memory even + // when the workload is fully active on unique keys. 
+ self.cleanup_local(Duration::from_secs(30)); + } } let mut bucket = self @@ -284,7 +310,8 @@ impl HybridRateLimiter { if allowed { Ok(RateLimitResult::allowed(remaining, reset_at)) } else { - let retry_after = bucket.time_until_token(); + let retry_after = + Duration::from_secs_f64(jittered(bucket.time_until_token().as_secs_f64())); Ok(RateLimitResult::denied(remaining, reset_at, retry_after)) } } diff --git a/crates/forge-runtime/src/realtime/listener.rs b/crates/forge-runtime/src/realtime/listener.rs index d5f8126a..574560cb 100644 --- a/crates/forge-runtime/src/realtime/listener.rs +++ b/crates/forge-runtime/src/realtime/listener.rs @@ -128,7 +128,16 @@ impl ChangeListener { }; let count = rows.len(); + // Track the highest seq across rows (including ones we skip because + // the op didn't parse) and only commit it after the whole batch is + // forwarded. Per-row stores let a live `rx.recv` racing this loop + // bump last_seq past unforwarded replay rows, and skipped rows used + // to leave a permanent gap that blocked future replays. + let mut max_seq = self.last_seq.load(Ordering::Relaxed); for row in &rows { + if row.seq > max_seq { + max_seq = row.seq; + } let Ok(operation) = row.op.parse::() else { continue; }; @@ -145,7 +154,9 @@ impl ChangeListener { } let _ = self.change_tx.send(change); - self.last_seq.store(row.seq, Ordering::Relaxed); + } + if max_seq > self.last_seq.load(Ordering::Relaxed) { + self.last_seq.store(max_seq, Ordering::Relaxed); } if count > 0 { @@ -228,6 +239,14 @@ impl ChangeListener { match result { Ok(payload) => { let recv_time = std::time::Instant::now(); + // Always recover any embedded seq, even on parse + // failure or unknown op, so a malformed/unknown + // payload doesn't pin the watermark and force + // the next reconnect-replay to refuse the gap. 
+ let trailing_seq = payload + .rsplit_once('#') + .and_then(|(_, s)| s.parse::().ok()) + .unwrap_or(0); if let Some((change, seq)) = self.parse_notification(&payload) { // Skip already-processed seqs to prevent // double-processing during the seed window. @@ -243,6 +262,11 @@ impl ChangeListener { crate::cluster::metrics::record_notification_latency(recv_time.elapsed().as_secs_f64()); } else { tracing::debug!(payload = %payload, "Failed to parse notification"); + if trailing_seq > 0 + && trailing_seq > self.last_seq.load(Ordering::Relaxed) + { + self.last_seq.store(trailing_seq, Ordering::Relaxed); + } } } Err(broadcast::error::RecvError::Lagged(n)) => { diff --git a/crates/forge-runtime/src/realtime/manager.rs b/crates/forge-runtime/src/realtime/manager.rs index a5b12d65..b19fcbba 100644 --- a/crates/forge-runtime/src/realtime/manager.rs +++ b/crates/forge-runtime/src/realtime/manager.rs @@ -111,14 +111,21 @@ impl SubscriptionManager { table_deps: &'static [&'static str], selected_cols: &'static [&'static str], ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> { - // Check per-session limit - if let Some(subs) = self.session_subscribers.get(&session_id) - && subs.len() >= self.max_per_session + // Reserve a slot under the entry write guard so two concurrent + // subscribes from the same session can't both observe `len < max` + // and race past the limit. We hold the slot for the rest of this + // call; if anything later fails we drop the placeholder so the + // user gets their seat back. 
+ let placeholder = SubscriberId(u32::MAX); { - return Err(forge_core::ForgeError::Validation(format!( - "Maximum subscriptions per session ({}) exceeded", - self.max_per_session - ))); + let mut entry = self.session_subscribers.entry(session_id).or_default(); + if entry.len() >= self.max_per_session { + return Err(forge_core::ForgeError::Validation(format!( + "Maximum subscriptions per session ({}) exceeded", + self.max_per_session + ))); + } + entry.push(placeholder); } let auth_scope = AuthScope::from_auth(auth_context); @@ -176,10 +183,16 @@ impl SubscriptionManager { group.subscribers.push(subscriber_id); } - self.session_subscribers - .entry(session_id) - .or_default() - .push(subscriber_id); + // Swap the placeholder we reserved earlier for the real id. If for + // any reason the placeholder is gone (e.g. concurrent session + // teardown), fall back to pushing. + if let Some(mut entry) = self.session_subscribers.get_mut(&session_id) { + if let Some(slot) = entry.iter_mut().find(|s| s.0 == u32::MAX) { + *slot = subscriber_id; + } else { + entry.push(subscriber_id); + } + } Ok((group_id, subscription_id, is_new)) } @@ -339,17 +352,24 @@ impl SubscriptionManager { /// runtime-discovered tables from the read set that weren't in the /// compile-time `table_deps`. pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) { - if let Some(mut group) = self.groups.get_mut(&group_id) { - for table in &read_set.tables { - let already_indexed = group.table_deps.iter().any(|t| *t == table); - if !already_indexed { - self.table_index - .entry(table.clone()) - .or_default() - .insert(group_id); - } - } + // Record execution BEFORE extending the table_index so a concurrent + // `find_affected_groups` can't observe the new table in the index + // but still see the old read_set on the group (which would make + // `should_invalidate` return false and silently drop the change). 
+ let new_tables: Vec = if let Some(mut group) = self.groups.get_mut(&group_id) { + let tables = read_set + .tables + .iter() + .filter(|t| !group.table_deps.contains(&t.as_str())) + .cloned() + .collect(); group.record_execution(read_set, result_hash); + tables + } else { + return; + }; + for table in new_tables { + self.table_index.entry(table).or_default().insert(group_id); } } @@ -368,16 +388,15 @@ impl SubscriptionManager { data: std::sync::Arc, serialized_len: usize, ) { - if let Some(mut group) = self.groups.get_mut(&group_id) { - for table in &read_set.tables { - let already_indexed = group.table_deps.iter().any(|t| *t == table); - if !already_indexed { - self.table_index - .entry(table.clone()) - .or_default() - .insert(group_id); - } - } + // Same ordering as `update_group`: record execution before + // publishing new tables into `table_index`. + let new_tables: Vec = if let Some(mut group) = self.groups.get_mut(&group_id) { + let tables = read_set + .tables + .iter() + .filter(|t| !group.table_deps.contains(&t.as_str())) + .cloned() + .collect(); if serialized_len > self.max_cached_result_bytes { tracing::debug!( @@ -390,6 +409,12 @@ impl SubscriptionManager { } else { group.record_execution_with_data(read_set, result_hash, data); } + tables + } else { + return; + }; + for table in new_tables { + self.table_index.entry(table).or_default().insert(group_id); } } diff --git a/crates/forge-runtime/src/realtime/message.rs b/crates/forge-runtime/src/realtime/message.rs index a5692251..f8572bd4 100644 --- a/crates/forge-runtime/src/realtime/message.rs +++ b/crates/forge-runtime/src/realtime/message.rs @@ -161,7 +161,39 @@ impl SessionServer { total_drops: AtomicU32::new(0), token_exp, }; - self.connections.insert(session_id, entry); + // Notify the displaced client before replacing it so it doesn't + // think it's still receiving live data on a dead channel. 
+ if let Some(prev) = self.connections.insert(session_id, entry) { + let _ = prev.sender.try_send(RealtimeMessage::AuthFailed { + reason: "Session replaced by a newer connection".to_string(), + }); + } + } + + /// Notify the client that its auth has been revoked, then tear down the + /// connection. Returns the subscription IDs the session held, so callers + /// can clean up associated query/job/workflow state. The notification is + /// best-effort (`try_send`): if the channel is full or closed, eviction + /// still proceeds. + /// + /// Use this when a session's underlying principal has been demoted, + /// tenant-moved, or revoked server-side before the JWT's `exp`. The + /// reactor's cached `AuthContext` on each `QueryGroup` is only + /// re-validated on token expiry, so without an explicit revocation path + /// the session would keep receiving data under the stale scope until + /// `exp`. After this call the client must reconnect and re-subscribe + /// with a fresh token. + pub fn revoke_session( + &self, + session_id: SessionId, + reason: &str, + ) -> Option> { + if let Some(conn) = self.connections.get(&session_id) { + let _ = conn.sender.try_send(RealtimeMessage::AuthFailed { + reason: reason.to_string(), + }); + } + self.remove_connection(session_id) } /// Remove a connection. @@ -353,7 +385,17 @@ impl SessionServer { } for (session_id, _) in stale { - self.remove_connection(session_id); + // Re-check last_active under the entry guard: a concurrent + // try_send may have bumped it between the snapshot and now, + // and evicting a connection that just successfully received + // traffic would drop a healthy client. 
+ let still_stale = self + .connections + .get(&session_id) + .is_some_and(|c| c.last_active.load(Ordering::Relaxed) < cutoff_ts); + if still_stale { + self.remove_connection(session_id); + } } } @@ -870,6 +912,44 @@ mod tests { assert!(evicted.is_empty()); } + #[tokio::test] + async fn revoke_session_notifies_then_evicts_connection_and_subscriptions() { + // Server-side auth revocation path: client gets one final AuthFailed + // message, the connection is removed, and the subscription mappings + // are returned for caller cleanup. After revocation the session must + // be unreachable — any send returns SessionNotFound, forcing the + // client to reconnect with a fresh token. + let server = SessionServer::new(NodeId::new(), RealtimeConfig::default()); + let session_id = SessionId::new(); + let sub_a = SubscriptionId::new(); + let sub_b = SubscriptionId::new(); + let (tx, mut rx) = mpsc::channel(8); + + server.register_connection(session_id, tx, None); + server.add_subscription(session_id, sub_a).unwrap(); + server.add_subscription(session_id, sub_b).unwrap(); + + let removed = server + .revoke_session(session_id, "role demoted") + .expect("session existed"); + assert_eq!(removed.len(), 2); + assert!(removed.contains(&sub_a)); + assert!(removed.contains(&sub_b)); + + match rx.recv().await { + Some(RealtimeMessage::AuthFailed { reason }) => assert_eq!(reason, "role demoted"), + other => panic!("expected AuthFailed, got {other:?}"), + } + + assert_eq!(server.connection_count(), 0); + assert_eq!(server.subscription_count(), 0); + let result = server.try_send_to_session(session_id, RealtimeMessage::Lagging); + assert!(matches!(result, Err(SendError::SessionNotFound))); + + // Calling revoke again on a gone session is a no-op. 
+ assert!(server.revoke_session(session_id, "again").is_none()); + } + #[test] fn cleanup_expired_tokens_returns_empty_when_nothing_expired() { let server = SessionServer::new(NodeId::new(), RealtimeConfig::default()); diff --git a/crates/forge-runtime/src/realtime/reactor.rs b/crates/forge-runtime/src/realtime/reactor.rs index bd490e86..b6724a72 100644 --- a/crates/forge-runtime/src/realtime/reactor.rs +++ b/crates/forge-runtime/src/realtime/reactor.rs @@ -186,10 +186,13 @@ impl Reactor { self.session_server.remove_connection(session_id); // Clean up job subscriptions using reverse index for O(1) lookup + // Lock order: subscriptions map BEFORE session reverse-index. + // Matches the cleanup task and `unsubscribe_job`/`unsubscribe_workflow` + // to remove the deadlock window where opposite orders could meet. { + let mut job_subs = self.job_subscriptions.write().await; let job_ids = self.session_job_ids.write().await.remove(&session_id); if let Some(ids) = job_ids { - let mut job_subs = self.job_subscriptions.write().await; for id in ids { if let Some(subscribers) = job_subs.get_mut(&id) { subscribers.retain(|s| s.session_id != session_id); @@ -203,9 +206,9 @@ impl Reactor { // Clean up workflow subscriptions using reverse index for O(1) lookup { + let mut workflow_subs = self.workflow_subscriptions.write().await; let wf_ids = self.session_workflow_ids.write().await.remove(&session_id); if let Some(ids) = wf_ids { - let mut workflow_subs = self.workflow_subscriptions.write().await; for id in ids { if let Some(subscribers) = workflow_subs.get_mut(&id) { subscribers.retain(|s| s.session_id != session_id); @@ -265,7 +268,23 @@ impl Reactor { } }; - let (result_hash, serialized_len) = Self::compute_hash(&data); + // A subscription with no observable table dependencies could + // never be invalidated, so it would sit live forever and never + // re-execute. Reject up front instead of silently going dark. 
+ if table_deps.is_empty() && read_set.tables.is_empty() { + self.unsubscribe(subscription_id); + return Err(forge_core::ForgeError::Validation(format!( + "Query '{}' has no table dependencies and cannot be subscribed to", + query_name + ))); + } + + let Some((result_hash, serialized_len)) = Self::compute_hash(&data) else { + self.unsubscribe(subscription_id); + return Err(forge_core::ForgeError::internal( + "Failed to serialize query result for change detection", + )); + }; tracing::trace!( ?group_id, @@ -349,26 +368,33 @@ impl Reactor { } /// Unsubscribe from job updates. - pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) { - let mut subs = self.job_subscriptions.write().await; - let mut removed_ids = Vec::new(); - for (job_id, subscribers) in subs.iter_mut() { - let before = subscribers.len(); - subscribers - .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); - if subscribers.len() < before { - removed_ids.push(*job_id); + /// + /// Requires `job_id` so this is O(subscribers for one job) instead of + /// walking every job entry. Callers always know it: it's the id they + /// passed to `subscribe_job`. + /// + /// Lock order: `job_subscriptions` -> `session_job_ids`. Matches + /// `remove_session` to prevent deadlocks under adversarial scheduling. 
+ pub async fn unsubscribe_job(&self, session_id: SessionId, job_id: Uuid, client_sub_id: &str) { + let removed = { + let mut subs = self.job_subscriptions.write().await; + let mut removed = false; + if let Some(subscribers) = subs.get_mut(&job_id) { + let before = subscribers.len(); + subscribers + .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); + removed = subscribers.len() < before; + if subscribers.is_empty() { + subs.remove(&job_id); + } } - } - subs.retain(|_, v| !v.is_empty()); - drop(subs); + removed + }; - if !removed_ids.is_empty() { + if removed { let mut session_jobs = self.session_job_ids.write().await; if let Some(ids) = session_jobs.get_mut(&session_id) { - for id in &removed_ids { - ids.remove(id); - } + ids.remove(&job_id); if ids.is_empty() { session_jobs.remove(&session_id); } @@ -410,27 +436,33 @@ impl Reactor { Ok(workflow_data) } - /// Unsubscribe from workflow updates. - pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) { - let mut subs = self.workflow_subscriptions.write().await; - let mut removed_ids = Vec::new(); - for (wf_id, subscribers) in subs.iter_mut() { - let before = subscribers.len(); - subscribers - .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); - if subscribers.len() < before { - removed_ids.push(*wf_id); + /// Unsubscribe from workflow updates. See [`unsubscribe_job`] for the + /// rationale on requiring the `workflow_id`. 
+ pub async fn unsubscribe_workflow( + &self, + session_id: SessionId, + workflow_id: Uuid, + client_sub_id: &str, + ) { + let removed = { + let mut subs = self.workflow_subscriptions.write().await; + let mut removed = false; + if let Some(subscribers) = subs.get_mut(&workflow_id) { + let before = subscribers.len(); + subscribers + .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); + removed = subscribers.len() < before; + if subscribers.is_empty() { + subs.remove(&workflow_id); + } } - } - subs.retain(|_, v| !v.is_empty()); - drop(subs); + removed + }; - if !removed_ids.is_empty() { + if removed { let mut session_wfs = self.session_workflow_ids.write().await; if let Some(ids) = session_wfs.get_mut(&session_id) { - for id in &removed_ids { - ids.remove(id); - } + ids.remove(&workflow_id); if ids.is_empty() { session_wfs.remove(&session_id); } @@ -465,14 +497,13 @@ impl Reactor { } /// Content hash for change detection; returns `(hash, byte_count)`. - fn compute_hash(data: &serde_json::Value) -> (String, usize) { - match serde_json::to_vec(data) { - Ok(bytes) => { - let len = bytes.len(); - (crate::stable_hash::sha256_hex(&bytes), len) - } - Err(_) => ("!serialization_failed!".to_string(), usize::MAX), - } + /// `None` if serialization fails — callers MUST skip the update so a + /// failure sentinel isn't cached and used to suppress later, legitimate + /// "still broken" notifications or emit spurious data on recovery. + fn compute_hash(data: &serde_json::Value) -> Option<(String, usize)> { + let bytes = serde_json::to_vec(data).ok()?; + let len = bytes.len(); + Some((crate::stable_hash::sha256_hex(&bytes), len)) } /// Flush pending invalidations with bounded concurrent re-execution. @@ -549,7 +580,59 @@ impl Reactor { .await; } + /// Drop every subscription for a session (query, job, workflow) and close + /// its SSE channel after a final `AuthFailed` notification. 
Intended as an + /// admin escape hatch for server-side auth revocation: cached + /// `AuthContext` on `QueryGroup` is captured at subscribe time and only + /// re-validated on JWT expiry, so demotions/tenant moves that happen + /// before `exp` cannot be detected by the reactor itself. Operators wire + /// this up to their identity system's revocation event; after the call + /// the client must reconnect and re-subscribe with a fresh token. + pub async fn revoke_session_auth(&self, session_id: SessionId, reason: &str) { + self.subscription_manager + .remove_session_subscriptions(session_id); + self.session_server.revoke_session(session_id, reason); + + { + let mut job_subs = self.job_subscriptions.write().await; + let job_ids = self.session_job_ids.write().await.remove(&session_id); + if let Some(ids) = job_ids { + for id in ids { + if let Some(subscribers) = job_subs.get_mut(&id) { + subscribers.retain(|s| s.session_id != session_id); + if subscribers.is_empty() { + job_subs.remove(&id); + } + } + } + } + } + + { + let mut workflow_subs = self.workflow_subscriptions.write().await; + let wf_ids = self.session_workflow_ids.write().await.remove(&session_id); + if let Some(ids) = wf_ids { + for id in ids { + if let Some(subscribers) = workflow_subs.get_mut(&id) { + subscribers.retain(|s| s.session_id != session_id); + if subscribers.is_empty() { + workflow_subs.remove(&id); + } + } + } + } + } + + tracing::info!(?session_id, %reason, "Session auth revoked"); + } + /// Re-run queries for groups, pushing to subscribers on hash change. + /// + /// Note: re-execution uses the `AuthContext` captured at subscribe time. + /// Authorization is checked at subscribe time and only re-checked on + /// token expiry. Server-side role/tenant changes before `exp` are not + /// detected here; callers must invoke [`revoke_session_auth`] to evict + /// affected sessions explicitly. 
async fn reexecute_groups( group_ids: &[forge_core::realtime::QueryGroupId], subscription_manager: &Arc, @@ -638,7 +721,13 @@ impl Reactor { }; match result { Ok((new_data, read_set)) => { - let (new_hash, serialized_len) = Self::compute_hash(&new_data); + let Some((new_hash, serialized_len)) = Self::compute_hash(&new_data) else { + tracing::warn!( + ?group_id, + "Skipping group update: result failed to serialize" + ); + continue; + }; if last_hash.as_ref() != Some(&new_hash) { let data_arc = std::sync::Arc::new(new_data); @@ -715,6 +804,7 @@ impl Reactor { tracing::debug!("Reactor listening for changes"); let mut restart_count: u32 = 0; + let mut consecutive_lags: u32 = 0; let (listener_error_tx, mut listener_error_rx) = mpsc::channel::(1); // Start initial listener @@ -757,6 +847,11 @@ impl Reactor { // again; reset so a long-lived process can absorb more // transient failures over its lifetime. restart_count = 0; + consecutive_lags = 0; + // A successful change proves the new listener + // is healthy; flush stale errors so they + // can't be misattributed to it later. + while listener_error_rx.try_recv().is_ok() {} Self::handle_change( &change, &invalidation_engine, @@ -767,10 +862,20 @@ impl Reactor { ).await; } Err(broadcast::error::RecvError::Lagged(n)) => { + // Back off exponentially on consecutive lags so a + // sustained event rate above the broadcast buffer + // doesn't pin us in a resync-storm. 
+ let backoff_ms = 100u64 + .saturating_mul(1u64 << consecutive_lags.min(6)); + consecutive_lags = consecutive_lags.saturating_add(1); tracing::warn!( missed = n, - "Reactor lagged; scheduling full resync" + consecutive_lags, + backoff_ms, + "Reactor lagged; backing off before scheduling full resync" ); + tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)) + .await; listener.set_needs_resync(); } Err(broadcast::error::RecvError::Closed) => { @@ -889,6 +994,14 @@ impl Reactor { if let Some(handle) = listener_handle.take() { handle.abort(); } + // Drain any further error messages that piled up + // while we were sleeping the backoff: the aborted + // listener may have queued additional failures, and + // the new listener must not be debited for them + // (each stale message would otherwise bump + // restart_count toward max_restarts on phantom + // restarts and emit a false "permanently failed"). + while listener_error_rx.try_recv().is_ok() {} change_rx = listener.subscribe(); listener_handle = Some(tokio::spawn(async move { if let Err(e) = listener_clone.run(&bus_clone).await { @@ -1299,18 +1412,12 @@ impl Reactor { let mut read_set = ReadSet::new(); - if info.table_dependencies.is_empty() { - let table_name = Self::extract_table_name(query_name); - read_set.add_table(&table_name); - tracing::trace!( - query = %query_name, - fallback_table = %table_name, - "Using naming convention fallback for table dependency" - ); - } else { - for table in info.table_dependencies { - read_set.add_table(*table); - } + // No naming-convention fallback: a fake "table" equal to + // the query name would never appear in real change events, + // so the subscription would silently never re-execute. + // Callers reject empty-deps subscriptions at subscribe time. 
+ for table in info.table_dependencies { + read_set.add_table(*table); } Ok((data, read_set)) @@ -1322,10 +1429,6 @@ impl Reactor { } } - fn extract_table_name(query_name: &str) -> String { - query_name.to_string() - } - /// Auth check for re-execution (authentication only, roles checked at subscribe time). fn check_query_auth( info: &forge_core::function::FunctionInfo, @@ -1461,9 +1564,9 @@ mod tests { let data2 = serde_json::json!({"name": "test"}); let data3 = serde_json::json!({"name": "different"}); - let (hash1, len1) = Reactor::compute_hash(&data1); - let (hash2, _) = Reactor::compute_hash(&data2); - let (hash3, _) = Reactor::compute_hash(&data3); + let (hash1, len1) = Reactor::compute_hash(&data1).expect("hash data1"); + let (hash2, _) = Reactor::compute_hash(&data2).expect("hash data2"); + let (hash3, _) = Reactor::compute_hash(&data3).expect("hash data3"); assert_eq!(hash1, hash2); assert_ne!(hash1, hash3); diff --git a/crates/forge-runtime/src/signals/bot.rs b/crates/forge-runtime/src/signals/bot.rs index 37ed45fb..912e2929 100644 --- a/crates/forge-runtime/src/signals/bot.rs +++ b/crates/forge-runtime/src/signals/bot.rs @@ -61,10 +61,14 @@ const BOT_PATTERNS: &[&str] = &[ "curl/", "libwww", "apache-httpclient", - "okhttp", - "node-fetch", - "axios", - "postman", + // Patterns below match the format these libraries actually emit in a UA + // (token + "/version"). Bare substrings like "axios" or "okhttp" flagged + // legitimate mobile apps and react-native clients that ship those names + // embedded inside larger UAs. + "okhttp/", + "node-fetch/", + "axios/", + "postmanruntime/", ]; /// Pre-compiled Aho-Corasick automaton for bot detection. 
diff --git a/crates/forge-runtime/src/signals/collector.rs b/crates/forge-runtime/src/signals/collector.rs index ab2df237..b8e975e4 100644 --- a/crates/forge-runtime/src/signals/collector.rs +++ b/crates/forge-runtime/src/signals/collector.rs @@ -12,6 +12,12 @@ use sqlx::PgPool; use tokio::sync::{Mutex, mpsc, oneshot}; use tracing::{debug, error, warn}; +/// Hard ceiling on the total in-buffer byte size before we force a flush. +/// Caps PG memory pressure when a single UNNEST batch would otherwise grow +/// into hundreds of MB. Tracked alongside `batch_size` (whichever fires +/// first wins). +const MAX_BUFFER_BYTES: usize = 16 * 1024 * 1024; + /// Buffered signal event collector. /// /// Clone-friendly (shares the mpsc sender). Send events from any async @@ -110,6 +116,7 @@ async fn flush_loop( mut shutdown_rx: oneshot::Receiver>, ) { let mut buffer: Vec = Vec::with_capacity(batch_size); + let mut buffer_bytes: usize = 0; let mut interval = tokio::time::interval(flush_interval); interval.tick().await; @@ -118,6 +125,7 @@ async fn flush_loop( biased; ack = &mut shutdown_rx => { while let Ok(event) = rx.try_recv() { + buffer_bytes = buffer_bytes.saturating_add(estimate_event_bytes(&event)); buffer.push(event); } if !buffer.is_empty() { @@ -132,9 +140,11 @@ async fn flush_loop( event = rx.recv() => { match event { Some(e) => { + buffer_bytes = buffer_bytes.saturating_add(estimate_event_bytes(&e)); buffer.push(e); - if buffer.len() >= batch_size { + if buffer.len() >= batch_size || buffer_bytes >= MAX_BUFFER_BYTES { flush_batch(&pool, &mut buffer).await; + buffer_bytes = 0; } } None => { @@ -149,12 +159,49 @@ async fn flush_loop( _ = interval.tick() => { if !buffer.is_empty() { flush_batch(&pool, &mut buffer).await; + buffer_bytes = 0; } } } } } +/// Cheap byte estimate dominated by the variable-size fields. 
Serializes +/// the properties and error_context JSON once per event to measure their byte +/// length (counting 0 if serialization fails), then adds the optional string +/// field lengths and a fixed overhead. +fn estimate_event_bytes(event: &SignalEvent) -> usize { + fn opt_len(s: &Option) -> usize { + s.as_deref().map(str::len).unwrap_or(0) + } + let props = serde_json::to_vec(&event.properties) + .map(|v| v.len()) + .unwrap_or(0); + let ctx = serde_json::to_vec(&event.error_context) + .map(|v| v.len()) + .unwrap_or(0); + // 256 = fixed-size column overhead (uuids, ints, timestamps, bools). + 256 + props + + ctx + + opt_len(&event.event_name) + + opt_len(&event.correlation_id) + + opt_len(&event.visitor_id) + + opt_len(&event.page_url) + + opt_len(&event.referrer) + + opt_len(&event.function_name) + + opt_len(&event.function_kind) + + opt_len(&event.status) + + opt_len(&event.error_message) + + opt_len(&event.error_stack) + + opt_len(&event.client_ip) + + opt_len(&event.country) + + opt_len(&event.city) + + opt_len(&event.user_agent) + + opt_len(&event.device_type) + + opt_len(&event.browser) + + opt_len(&event.os) +} + /// Flush a batch of events into PostgreSQL using UNNEST for single-roundtrip INSERT. /// Uses runtime sqlx::query() because UNNEST with typed arrays is not supported by /// the compile-time sqlx::query!() macro. diff --git a/crates/forge-runtime/src/signals/endpoints.rs b/crates/forge-runtime/src/signals/endpoints.rs index e497181d..a49502a7 100644 --- a/crates/forge-runtime/src/signals/endpoints.rs +++ b/crates/forge-runtime/src/signals/endpoints.rs @@ -32,6 +32,15 @@ use super::visitor; /// Maximum events per batch request. const MAX_BATCH_SIZE: usize = 50; +/// Maximum serialized byte size of a single event's free-form `properties` +/// JSON. Larger payloads are rejected. Prevents apps from dumping request +/// bodies / PII into analytics rows. 
+const MAX_PROPERTY_BYTES: usize = 4096; + +/// Maximum serialized byte size of a single event envelope (event name + +/// properties + correlation_id). Larger batch entries are rejected. +const MAX_EVENT_BYTES: usize = 8192; + /// Check the client's Do-Not-Track header. We honor DNT: 1 by short-circuiting /// signal ingestion -- the browser has explicitly opted out of tracking. /// Sec-GPC (Global Privacy Control) is also respected. @@ -125,6 +134,14 @@ async fn handle_event( if batch.events.len() > MAX_BATCH_SIZE { return rate_limited_response(); } + for event in &batch.events { + if !event_within_limits(event) { + return Json(SignalResponse { + ok: false, + session_id: None, + }); + } + } let ctx = extract_request_ctx( headers, @@ -133,28 +150,22 @@ async fn handle_event( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); - let session_id = + ) + .await; + let supplied_session_id = resolve_session_id(batch.context.as_ref().and_then(|c| c.session_id.as_deref())); + let session_id = Some(supplied_session_id.unwrap_or_else(Uuid::new_v4)); let page_url = batch.context.as_ref().and_then(|c| c.page_url.clone()); - let session_id = session::upsert_session( - &state.pool, + let referrer = batch.context.as_ref().and_then(|c| c.referrer.clone()); + spawn_session_upsert( + state.pool.clone(), session_id, - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - page_url.as_deref(), - batch.context.as_ref().and_then(|c| c.referrer.as_deref()), - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, + &ctx, + page_url.clone(), + referrer, "track", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + ); for event in batch.events { let signal = SignalEvent { @@ -216,27 +227,20 @@ async fn handle_view( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); + ) + .await; let session_id_header = extract_header(headers, "x-session-id"); - let session_id = 
resolve_session_id(session_id_header.as_deref()); + let supplied_session_id = resolve_session_id(session_id_header.as_deref()); + let session_id = Some(supplied_session_id.unwrap_or_else(Uuid::new_v4)); - let session_id = session::upsert_session( - &state.pool, + spawn_session_upsert( + state.pool.clone(), session_id, - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - Some(&payload.url), - payload.referrer.as_deref(), - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, + &ctx, + Some(payload.url.clone()), + payload.referrer.clone(), "page_view", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + ); let utm = if payload.utm_source.is_some() || payload.utm_medium.is_some() @@ -314,28 +318,13 @@ async fn handle_report( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); + ) + .await; let session_id_header = extract_header(headers, "x-session-id"); let session_id = resolve_session_id(session_id_header.as_deref()); - if let Some(sid) = session_id { - session::upsert_session( - &state.pool, - Some(sid), - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - None, - None, - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, - "error", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + if session_id.is_some() { + spawn_session_upsert(state.pool.clone(), session_id, &ctx, None, None, "error"); } for err in report.errors { @@ -392,7 +381,7 @@ struct RequestCtx { os: Option, } -fn extract_request_ctx( +async fn extract_request_ctx( headers: &HeaderMap, resolved_ip: Option, auth: &Option>, @@ -413,12 +402,26 @@ fn extract_request_ctx( let user_id = auth.as_ref().and_then(|a| a.user_id()); let tenant_id = auth.as_ref().and_then(|a| a.tenant_id()); let device_info = device::parse_lowered(platform_header.as_deref(), &ua_lower); - let geo = geoip - .zip(raw_ip.as_deref()) - .map(|(g, ip)| g.lookup(ip)) - .unwrap_or_default(); + let geo = match 
(geoip, raw_ip.clone()) { + (Some(g), Some(ip)) => { + // MMDB lookups can be CPU-blocking on cold pages; offload so the + // request thread keeps feeding the collector. + let g = g.clone(); + tokio::task::spawn_blocking(move || g.lookup(&ip)) + .await + .unwrap_or_default() + } + _ => super::geoip::GeoInfo::default(), + }; // anonymize_ip drops the raw IP after visitor_id + geo are derived; GDPR-friendly default. let client_ip = if anonymize_ip { None } else { raw_ip }; + // When IP is anonymized, also strip the UA major-version so the combo of + // UA + country + city can't be used to re-fingerprint the visitor. + let user_agent = if anonymize_ip { + user_agent.as_deref().map(anonymize_ua) + } else { + user_agent + }; RequestCtx { user_agent, client_ip, @@ -438,6 +441,76 @@ fn extract_header(headers: &HeaderMap, name: &str) -> Option { crate::gateway::extract_header(headers, name) } +/// Strip the major version off a UA so a per-version identifier can't be +/// derived. Recognizes the most common browser family tokens; falls back to +/// the broad family when the UA doesn't match any known prefix. +fn anonymize_ua(ua: &str) -> String { + const FAMILIES: &[&str] = &["Chrome/", "Firefox/", "Safari/", "Edg/", "Opera/"]; + for family in FAMILIES { + if ua.contains(family) { + return (*family).to_string(); + } + } + "Other".to_string() +} + +/// Per-event size guard. Rejects events whose serialized properties / event +/// envelope exceed configured limits. +fn event_within_limits(event: &forge_core::signals::ClientEvent) -> bool { + let props_bytes = match serde_json::to_vec(&event.properties) { + Ok(b) => b.len(), + Err(_) => return false, + }; + if props_bytes > MAX_PROPERTY_BYTES { + return false; + } + let total = event.event.len() + + props_bytes + + event.correlation_id.as_deref().map(str::len).unwrap_or(0); + total <= MAX_EVENT_BYTES +} + +/// Fire-and-forget the session upsert so the request thread doesn't block on +/// a PG round-trip. 
We mint the session ID synchronously upstream so the +/// response can return it before the row is persisted. +fn spawn_session_upsert( + pool: PgPool, + session_id: Option, + ctx: &RequestCtx, + page_url: Option, + referrer: Option, + event_type: &'static str, +) { + let visitor_id = ctx.visitor_id.clone(); + let user_id = ctx.user_id; + let tenant_id = ctx.tenant_id; + let user_agent = ctx.user_agent.clone(); + let client_ip = ctx.client_ip.clone(); + let device_type = ctx.device_type.clone(); + let browser = ctx.browser.clone(); + let os = ctx.os.clone(); + let is_bot = ctx.is_bot; + tokio::spawn(async move { + session::upsert_session( + &pool, + session_id, + &visitor_id, + user_id, + tenant_id, + page_url.as_deref(), + referrer.as_deref(), + user_agent.as_deref(), + client_ip.as_deref(), + is_bot, + event_type, + device_type.as_deref(), + browser.as_deref(), + os.as_deref(), + ) + .await; + }); +} + fn resolve_session_id(raw: Option<&str>) -> Option { raw.and_then(|s| Uuid::parse_str(s).ok()) } diff --git a/crates/forge-runtime/src/signals/partition.rs b/crates/forge-runtime/src/signals/partition.rs index 0e09fc33..bb794568 100644 --- a/crates/forge-runtime/src/signals/partition.rs +++ b/crates/forge-runtime/src/signals/partition.rs @@ -2,6 +2,27 @@ //! //! Creates partitions for upcoming months and drops partitions //! older than the configured retention period. +//! +//! ## Operational expectation +//! +//! Pre-creation runs every maintenance tick and covers the current month plus +//! the next three months. Anything farther out lands in the catch-all +//! `forge_signals_events_default` partition, which is **excluded from the +//! retention sweep** — rows there accumulate forever and won't be dropped. +//! +//! Two failure modes to watch: +//! +//! 1. A node hibernates / loses its scheduler past the +3-month horizon. When +//! it wakes back up, any inserts whose `timestamp` falls outside the +//! 
rolling window land in the default partition until the maintenance loop +//! next runs. +//! 2. A client sends events with `timestamp` far in the future (clock skew, +//! backfill jobs). Same outcome. +//! +//! `check_default_partition` logs an error whenever rows are present in the +//! default partition. Treat that log as actionable: investigate the coverage +//! gap, then move the misrouted rows into the correct month partition by +//! hand (or accept that they'll never be cleaned up by retention). // Partition DDL constructs table names from runtime dates, so the query macros // can't validate them at compile time. diff --git a/crates/forge-runtime/src/signals/rate_limit.rs b/crates/forge-runtime/src/signals/rate_limit.rs index b2ddaa6b..af78dbda 100644 --- a/crates/forge-runtime/src/signals/rate_limit.rs +++ b/crates/forge-runtime/src/signals/rate_limit.rs @@ -8,6 +8,8 @@ //! limit is effectively `nodes * max_per_window`, which is fine for abuse //! protection — billing-grade limits are not the goal here. +use std::collections::VecDeque; +use std::sync::Mutex; use std::sync::atomic::{AtomicI64, AtomicU32, Ordering}; use dashmap::DashMap; @@ -16,6 +18,10 @@ use dashmap::DashMap; /// minute. Generous enough to absorb legitimate bursts (page-view + web-vital /// flush + a handful of tracked events on a navigation) while still capping /// runaway clients. +/// +/// TODO(signals-config): expose `SignalsConfig::rate_limit_per_minute` and +/// thread it through `gateway::server` so operators can tune this in +/// forge.toml. Until then, callers can override via `with_limit(...)`. const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 600; /// Window length in seconds. @@ -29,6 +35,11 @@ const MAX_TRACKED_IPS: usize = 100_000; pub struct SignalRateLimiter { max_per_window: u32, buckets: DashMap, + /// Insertion-order queue. When `buckets.len()` exceeds `MAX_TRACKED_IPS` + /// we pop the oldest entry off the front and remove it from the map. 
+ /// O(1) amortized — replaces the previous O(n) `evict_oldest` sweep that + /// ran inline on every new-IP miss. + insertion_order: Mutex>, } struct IpBucket { @@ -47,6 +58,7 @@ impl SignalRateLimiter { Self { max_per_window, buckets: DashMap::new(), + insertion_order: Mutex::new(VecDeque::new()), } } @@ -68,14 +80,30 @@ impl SignalRateLimiter { bucket.count.store(1, Ordering::Relaxed); return true; } + // Note: fetch_add followed by the prev < max comparison is a + // benign race — two concurrent callers can both observe a value + // just under the ceiling and both succeed, leaving the bucket a + // small constant over the configured max. Acceptable for abuse + // protection; billing-grade limits are explicitly out of scope. let prev = bucket.count.fetch_add(1, Ordering::Relaxed); return prev < self.max_per_window; } - if self.buckets.len() >= MAX_TRACKED_IPS { - self.evict_oldest(); - } + self.insert_new_bucket(ip, now); + true + } + /// Insert a freshly-seen IP. Evicts the oldest tracked IP in O(1) when the + /// map is at capacity (FIFO; not strictly LRU but good enough — abuse- + /// driven floods churn the queue fast enough that any stale entries fall + /// off naturally). 
+ fn insert_new_bucket(&self, ip: &str, now: i64) { + if self.buckets.len() >= MAX_TRACKED_IPS + && let Ok(mut order) = self.insertion_order.lock() + && let Some(victim) = order.pop_front() + { + self.buckets.remove(&victim); + } self.buckets.insert( ip.to_string(), IpBucket { @@ -83,13 +111,9 @@ impl SignalRateLimiter { count: AtomicU32::new(1), }, ); - true - } - - fn evict_oldest(&self) { - let cutoff = chrono::Utc::now().timestamp() - WINDOW_SECS; - self.buckets - .retain(|_, bucket| bucket.window_start.load(Ordering::Relaxed) >= cutoff); + if let Ok(mut order) = self.insertion_order.lock() { + order.push_back(ip.to_string()); + } } } diff --git a/crates/forge-runtime/src/signals/session.rs b/crates/forge-runtime/src/signals/session.rs index d049dffd..3a3b535f 100644 --- a/crates/forge-runtime/src/signals/session.rs +++ b/crates/forge-runtime/src/signals/session.rs @@ -69,7 +69,12 @@ pub async fn upsert_session( } } - let new_id = Uuid::new_v4(); + // Reuse the caller-supplied session id when present: the handler already + // returned it to the client, so the persisted row MUST carry that same id. + // Minting a fresh UUID here orphaned the row under an id the client never + // saw, so every later event re-missed the UPDATE and spawned a new session — + // breaking session continuity. Only generate when no id was supplied. + let new_id = session_id.unwrap_or_else(Uuid::new_v4); let referrer_domain = referrer.and_then(extract_domain); let result = sqlx::query( diff --git a/crates/forge-runtime/src/signals/tests.rs b/crates/forge-runtime/src/signals/tests.rs index 91bdfa32..9d5715d0 100644 --- a/crates/forge-runtime/src/signals/tests.rs +++ b/crates/forge-runtime/src/signals/tests.rs @@ -732,3 +732,103 @@ async fn test_partition_ensure() { db.cleanup().await.unwrap(); } + +// ── Privacy short-circuit ─────────────────────────────────────────────────── + +/// `DNT: 1` opts the visitor out of analytics. 
The endpoint must return +/// `ok: true` (so the client doesn't keep retrying) without persisting any +/// event row. A regression that drops the short-circuit would silently +/// re-enable tracking for opted-out users. +#[tokio::test] +async fn dnt_header_short_circuits_event_storage() { + let db = setup("dnt_short_circuit").await; + let state = make_signals_state(db.pool()); + + let batch = SignalEventBatch { + events: vec![ClientEvent { + event: "dnt_should_not_persist".to_string(), + properties: serde_json::json!({}), + correlation_id: None, + timestamp: None, + }], + context: None, + }; + + let mut headers = make_headers(); + headers.insert("dnt", HeaderValue::from_static("1")); + + let response = endpoints::signal_handler( + State(state.clone()), + None, + None, + headers, + Json(SignalPayload::Event(batch)), + ) + .await + .into_response(); + + let body: serde_json::Value = axum::body::to_bytes(response.into_body(), 1024) + .await + .map(|b| serde_json::from_slice(&b).unwrap()) + .unwrap(); + assert_eq!( + body["ok"], true, + "DNT must return ok so clients stop retrying" + ); + + // Give the collector a chance to flush — if anything was queued it would + // land in this window. We assert nothing was stored. + tokio::time::sleep(Duration::from_millis(200)).await; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM forge_signals_events WHERE event_name = 'dnt_should_not_persist'", + ) + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!(count.0, 0, "DNT events must not be persisted"); + + db.cleanup().await.unwrap(); +} + +/// `Sec-GPC: 1` (Global Privacy Control) is the modern equivalent of DNT +/// and must short-circuit the same way. 
+#[tokio::test] +async fn sec_gpc_header_short_circuits_event_storage() { + let db = setup("gpc_short_circuit").await; + let state = make_signals_state(db.pool()); + + let batch = SignalEventBatch { + events: vec![ClientEvent { + event: "gpc_should_not_persist".to_string(), + properties: serde_json::json!({}), + correlation_id: None, + timestamp: None, + }], + context: None, + }; + + let mut headers = make_headers(); + headers.insert("sec-gpc", HeaderValue::from_static("1")); + + let _ = endpoints::signal_handler( + State(state.clone()), + None, + None, + headers, + Json(SignalPayload::Event(batch)), + ) + .await; + + tokio::time::sleep(Duration::from_millis(200)).await; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM forge_signals_events WHERE event_name = 'gpc_should_not_persist'", + ) + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!(count.0, 0, "Sec-GPC events must not be persisted"); + + db.cleanup().await.unwrap(); +} diff --git a/crates/forge-runtime/src/signals/visitor.rs b/crates/forge-runtime/src/signals/visitor.rs index 0f63bedc..f1896549 100644 --- a/crates/forge-runtime/src/signals/visitor.rs +++ b/crates/forge-runtime/src/signals/visitor.rs @@ -8,6 +8,17 @@ use sha2::{Digest, Sha256}; use std::sync::RwLock; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// Must stay in sync with `gateway::server::DEFAULT_SIGNAL_SECRET`. If a +/// caller passes this literal we refuse to emit a real visitor ID, since +/// the daily salt would be trivially reversible by anyone who reads the +/// open-source repo. +const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret"; + +/// One-shot guard so we only log the "default secret in use" warning once +/// rather than on every request. +static WARNED_DEFAULT_SECRET: AtomicBool = AtomicBool::new(false); /// Cached daily salt to avoid recomputing on every request. 
struct DailySalt { @@ -29,6 +40,15 @@ pub fn generate_visitor_id( user_agent: Option<&str>, server_secret: &str, ) -> String { + if server_secret == DEFAULT_SIGNAL_SECRET { + if !WARNED_DEFAULT_SECRET.swap(true, Ordering::Relaxed) { + tracing::error!( + "signals: default visitor-ID secret in use; refusing to emit a real visitor ID. \ + Configure [auth] jwt_secret in forge.toml to enable visitor tracking." + ); + } + return String::new(); + } let ip = client_ip.unwrap_or("unknown"); let ua = user_agent.unwrap_or("unknown"); let salt = get_daily_salt(server_secret); diff --git a/crates/forge-runtime/src/webhook/handler.rs b/crates/forge-runtime/src/webhook/handler.rs index 4855966e..10276c47 100644 --- a/crates/forge-runtime/src/webhook/handler.rs +++ b/crates/forge-runtime/src/webhook/handler.rs @@ -13,6 +13,7 @@ use axum::{ use base64::{Engine as _, engine::general_purpose}; use forge_core::CircuitBreakerClient; use forge_core::function::{JobDispatch, KvHandle, WorkflowDispatch}; +use forge_core::rate_limit::{RateLimitConfig, RateLimitKey}; use forge_core::webhook::{ IdempotencySource, REPLAY_TIMESTAMP_HEADER, SignatureAlgorithm, WebhookContext, }; @@ -21,11 +22,30 @@ use ring::signature::{self, UnparsedPublicKey}; use serde_json::{Value, json}; use sha2::Sha256; use sqlx::PgPool; +use std::time::Duration; use tracing::{error, info, warn}; use uuid::Uuid; use super::registry::WebhookRegistry; -use crate::gateway::RpcError; +use crate::gateway::{ResolvedClientIp, RpcError}; +use crate::rate_limit::HybridRateLimiter; + +/// Hard cap on the inbound webhook body, also bounding the bytes persisted to +/// `forge_webhook_events.raw_body`. Without this cap a misbehaving sender can +/// fill the events table with multi-MB payloads, and unsigned webhooks have no +/// other guard at all. 
+const MAX_WEBHOOK_BODY_BYTES: usize = 1024 * 1024; + +/// Cap the number of comma-separated rotation secrets we will HMAC per request +/// to bound the work an attacker spraying invalid signatures can force. +const MAX_WEBHOOK_SECRETS: usize = 4; + +/// Default cap on unsigned webhook deliveries per source IP per minute. +/// +/// `allow_unsigned = true` opts out of signature validation; without flow +/// control any caller reaching the URL can spray dispatches and pollute the +/// idempotency table. The DDoS cost of unsigned endpoints is bounded here. +const UNSIGNED_RATE_LIMIT_PER_MINUTE: u32 = 60; /// State for webhook handler. #[derive(Clone)] @@ -36,10 +56,12 @@ pub struct WebhookState { job_dispatcher: Option>, workflow_dispatcher: Option>, kv: Option>, + unsigned_rate_limiter: Arc, } impl WebhookState { pub fn new(registry: Arc, pool: PgPool) -> Self { + let unsigned_rate_limiter = Arc::new(HybridRateLimiter::new(pool.clone())); Self { registry, pool, @@ -47,6 +69,7 @@ impl WebhookState { job_dispatcher: None, workflow_dispatcher: None, kv: None, + unsigned_rate_limiter, } } @@ -70,12 +93,29 @@ impl WebhookState { pub async fn webhook_handler( State(state): State>, Path(path): Path, + axum::extract::Extension(client_ip): axum::extract::Extension, headers: HeaderMap, body: Bytes, ) -> Response { let full_path = format!("/webhooks/{}", path); let request_id = Uuid::new_v4().to_string(); + if body.len() > MAX_WEBHOOK_BODY_BYTES { + warn!( + path = %full_path, + body_size = body.len(), + "Webhook body exceeds maximum size" + ); + return ( + StatusCode::PAYLOAD_TOO_LARGE, + Json(RpcError::new( + "PAYLOAD_TOO_LARGE", + "Webhook payload exceeds maximum size", + )), + ) + .into_response(); + } + let entry = match state.registry.get_by_path(&full_path) { Some(e) => e, None => { @@ -108,6 +148,50 @@ pub async fn webhook_handler( .into_response(); } + // Flow-control for unsigned webhooks: signature validation is what bounds + // the cost of an attacker spraying 
requests against the dispatch + idempotency + // path. When `allow_unsigned = true` we still need a per-IP ceiling to keep + // the endpoint from being a free amplification vector. + if info.signature.is_none() && info.allow_unsigned { + let ip_key = client_ip.0.as_deref().unwrap_or("unknown").to_string(); + let bucket = format!("webhook:unsigned:{}:{}", info.name, ip_key); + let config = RateLimitConfig::new(UNSIGNED_RATE_LIMIT_PER_MINUTE, Duration::from_secs(60)) + .with_key(RateLimitKey::Ip); + match state.unsigned_rate_limiter.check(&bucket, &config).await { + Ok(result) if !result.allowed => { + let retry_after = result + .retry_after + .unwrap_or(Duration::from_secs(1)) + .as_secs() + .max(1); + warn!( + webhook = info.name, + ip = %ip_key, + "Unsigned webhook rate-limited" + ); + let mut resp = ( + StatusCode::TOO_MANY_REQUESTS, + Json(RpcError::new( + "RATE_LIMITED", + "Too many unsigned webhook deliveries from this client", + )), + ) + .into_response(); + if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after.to_string()) { + resp.headers_mut().insert("Retry-After", val); + } + return resp; + } + Ok(_) => {} + Err(e) => { + // Failing closed on the rate limit would make a transient PG + // hiccup take the webhook endpoint down; failing open keeps the + // already-cheap unsigned path serving, with a loud log. + warn!(webhook = info.name, error = %e, "Unsigned webhook rate-limit check failed; allowing"); + } + } + } + if let Some(ref sig_config) = info.signature { let signature = match headers .get(sig_config.header_name) @@ -141,10 +225,14 @@ pub async fn webhook_handler( } }; + // Cap secrets considered per request so an attacker spraying invalid + // signatures can't force unbounded HMACs when many rotation secrets + // are configured. 
let secrets: Vec<&str> = secrets_raw .split(',') .map(str::trim) .filter(|s| !s.is_empty()) + .take(MAX_WEBHOOK_SECRETS) .collect(); let signature_valid = secrets.iter().any(|secret| { validate_signature( @@ -167,7 +255,7 @@ pub async fn webhook_handler( } let idempotency_key = if let Some(ref idem_config) = info.idempotency { - match &idem_config.source { + let extracted = match &idem_config.source { IdempotencySource::Header(header_name) => headers .get(*header_name) .and_then(|v| v.to_str().ok()) @@ -181,7 +269,23 @@ pub async fn webhook_handler( } // Future IdempotencySource variants: skip key extraction. _ => None, + }; + // Idempotency was opted into; missing/malformed keys must fail closed + // rather than silently running the handler without replay protection. + if extracted.is_none() { + warn!( + webhook = info.name, + "Idempotency configured but key could not be extracted" + ); + return ( + StatusCode::BAD_REQUEST, + Json(RpcError::validation( + "Required idempotency key is missing or malformed", + )), + ) + .into_response(); } + extracted } else { None }; @@ -595,10 +699,15 @@ fn validate_stripe_webhooks( Ok(n) => n, Err(_) => return false, }; - if replay_window_secs > 0 - && (chrono::Utc::now().timestamp() - ts).unsigned_abs() > replay_window_secs - { - return false; + if replay_window_secs > 0 { + let now = chrono::Utc::now().timestamp(); + let window = i64::try_from(replay_window_secs).unwrap_or(i64::MAX); + let age = now.saturating_sub(ts); + // Reject future timestamps (age < 0) and stale ones uniformly with the + // generic replay window check. 
+ if !(0..=window).contains(&age) { + return false; + } } let mut signed = Vec::with_capacity(timestamp.len() + 1 + body.len()); diff --git a/crates/forge-runtime/src/workflow/bridge.rs b/crates/forge-runtime/src/workflow/bridge.rs index 3afe9bbe..f9e16000 100644 --- a/crates/forge-runtime/src/workflow/bridge.rs +++ b/crates/forge-runtime/src/workflow/bridge.rs @@ -37,6 +37,28 @@ pub fn register_workflow_bridge(executor: Arc, job_registry: & let cancel: bool = args.get("cancel").and_then(Value::as_bool).unwrap_or(false); if cancel { + // Only the scheduler (`enqueue_cancel`) sets `resume_reason == + // "cancel"`. An external caller dispatching `$workflow_resume` + // via `JobDispatcher::dispatch_by_name` with `{cancel: true}` + // would NOT set this marker, so we reject — defense in depth + // against a compromised internal caller cancelling arbitrary + // runs (#11 in issues doc). The `$` prefix on the job_type is + // convention; this guard is enforcement. + let scheduler_marker = args + .get("resume_reason") + .and_then(Value::as_str) + .map(|s| s == "cancel") + .unwrap_or(false); + if !scheduler_marker { + tracing::error!( + workflow_run_id = %run_id, + "Rejected $workflow_resume cancel without scheduler marker; \ + only WorkflowScheduler may dispatch cancel jobs" + ); + return Err(forge_core::ForgeError::Forbidden( + "cancel jobs may only be dispatched by the workflow scheduler".to_string(), + )); + } let reason = args .get("reason") .and_then(Value::as_str) diff --git a/crates/forge-runtime/src/workflow/executor.rs b/crates/forge-runtime/src/workflow/executor.rs index c985796e..3f52e761 100644 --- a/crates/forge-runtime/src/workflow/executor.rs +++ b/crates/forge-runtime/src/workflow/executor.rs @@ -3,7 +3,8 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::RwLock; +use dashmap::DashMap; +use tokio::sync::{Mutex, RwLock}; use uuid::Uuid; use super::bridge::WORKFLOW_RESUME_JOB; @@ -50,6 +51,13 @@ pub struct WorkflowExecutor { 
job_queue: JobQueue, http_client: CircuitBreakerClient, compensation_state: Arc>>, + /// Per-run serialization: execute and cancel of the same run_id never + /// overlap. Without this guard a cancel landing concurrently with a + /// resume would yank the live `CompensationState` and double-fire + /// compensation while the handler is still mid-step (#3 in issues doc). + /// Entry is removed by the holder when the workflow reaches a terminal + /// state; in-flight holders elsewhere keep their Arc alive. + run_locks: Arc>>>, kv: Option>, } @@ -66,10 +74,21 @@ impl WorkflowExecutor { job_queue, http_client, compensation_state: Arc::new(RwLock::new(HashMap::new())), + run_locks: Arc::new(DashMap::new()), kv: None, } } + /// Returns an `Arc>` keyed by `run_id`, creating one if absent. + /// Holders take the mutex around any code that touches `compensation_state` + /// or runs the workflow handler to keep cancel and execute serialized. + fn run_lock(&self, run_id: Uuid) -> Arc> { + self.run_locks + .entry(run_id) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + pub fn with_kv(mut self, kv: Arc) -> Self { self.kv = Some(kv); self @@ -137,6 +156,12 @@ impl WorkflowExecutor { .await .map_err(forge_core::ForgeError::Database)?; + // #15: A DB trigger on `forge_jobs` (`v001_initial.sql`) already + // PERFORM pg_notify('forge_jobs_available', ...) on every insert. + // PostgreSQL buffers NOTIFY in the source transaction and delivers on + // commit, so the resume job is visible to workers as soon as the row + // is. No extra NOTIFY needed here. + Ok(run_id) } @@ -148,6 +173,10 @@ impl WorkflowExecutor { resume: Option, owner_subject: Option, ) -> forge_core::Result { + // Serialize against concurrent cancel for the same run. 
+ let lock = self.run_lock(run_id); + let _guard = lock.lock().await; + self.claim_for_execution(run_id).await?; let signal_label = if resume.is_some() { @@ -219,7 +248,16 @@ impl WorkflowExecutor { let handler = entry.handler.clone(); let exec_start = std::time::Instant::now(); - let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await; + // PER-RESUME timeout. `entry.info.timeout` bounds a single resume + // call, not the whole workflow run. A workflow that sleeps for an + // hour and then runs for 4m59s under a 5m timeout will pass, even + // if it suspends and resumes many times — total wall-clock is + // unbounded by this guard (#4 in issues doc). Tracking total-run + // budget requires a new column on `forge_workflow_runs`; until + // then the field name is intentionally treated as `step_timeout` + // semantics by callers. + let step_timeout = entry.info.timeout; + let result = tokio::time::timeout(step_timeout, handler(&ctx, input)).await; let exec_duration_ms = exec_start.elapsed().as_millis().min(i32::MAX as u128) as i32; let comp = CompensationState { @@ -377,22 +415,48 @@ impl WorkflowExecutor { /// operators know manual remediation is required. This is an honest limitation: /// in-memory closures cannot survive restarts. pub async fn cancel(&self, run_id: Uuid, reason: &str) -> forge_core::Result<()> { + // Serialize against a concurrent resume/execute for this run. Without + // the lock the live handler could be mid-step while we yank its + // compensation state and flip the row to `failed` (#3 in issues doc). 
+ let lock = self.run_lock(run_id); + let _guard = lock.lock().await; + if let Some(state) = self.compensation_state.write().await.remove(&run_id) { - self.run_compensation(run_id, &state).await?; - let error = format!("cancelled: {reason}"); - self.fail_workflow(run_id, &error).await?; + let comp_failures = self.run_compensation(run_id, &state).await?; + // #17/#18: persist the cancel reason in a dedicated column and + // surface compensation failures as a structured summary in `error` + // so operators don't have to grep logs to learn which steps need + // manual remediation. + let comp_summary = if comp_failures.is_empty() { + None + } else { + Some(format!( + "{} compensation(s) failed: {}", + comp_failures.len(), + comp_failures.join("; ") + )) + }; + self.finalize_cancel(run_id, reason, comp_summary.as_deref()) + .await?; } else { tracing::error!( workflow_run_id = %run_id, "Compensation handlers lost (process restarted since workflow began); \ manual remediation required for any side effects from completed steps" ); - let error = format!( - "cancelled: {reason} (compensation skipped: handlers lost on restart, manual remediation required)" - ); - self.fail_workflow(run_id, &error).await?; + self.finalize_cancel( + run_id, + reason, + Some("compensation skipped: handlers lost on restart, manual remediation required"), + ) + .await?; } + // Run is terminal; drop the per-run mutex entry so the map doesn't + // accumulate. Holders that captured the Arc before this point keep + // their own reference alive. + self.run_locks.remove(&run_id); + Ok(()) } @@ -411,8 +475,20 @@ impl WorkflowExecutor { /// /// Returns `false` if the run is already in a terminal state or no row /// matched. 
- pub async fn request_cancel(&self, run_id: Uuid, reason: &str) -> forge_core::Result { - let result = sqlx::query!( + pub async fn request_cancel( + &self, + run_id: Uuid, + reason: &str, + caller_subject: Option<&str>, + ) -> forge_core::Result { + // Ownership check parallels `JobQueue::request_cancel`: if the run has + // an `owner_subject`, the caller must match (or be `None`, which means + // system / internal). Without this any caller holding the dispatcher + // could cancel any workflow run by ID (#10 in issues doc). + // + // Runtime query — adds a parameter; avoids invalidating .sqlx/. + #[allow(clippy::disallowed_methods)] + let result = sqlx::query( r#" UPDATE forge_workflow_runs SET cancel_requested_at = NOW(), @@ -420,10 +496,16 @@ impl WorkflowExecutor { WHERE id = $1 AND status IN ('pending', 'running', 'sleeping', 'waiting') AND cancel_requested_at IS NULL + AND ( + owner_subject IS NULL + OR $3::text IS NULL + OR owner_subject = $3::text + ) "#, - run_id, - reason, ) + .bind(run_id) + .bind(reason) + .bind(caller_subject) .execute(&self.pool) .await .map_err(forge_core::ForgeError::Database)?; @@ -435,8 +517,9 @@ impl WorkflowExecutor { &self, run_id: Uuid, state: &CompensationState, - ) -> forge_core::Result<()> { + ) -> forge_core::Result> { let steps = self.get_workflow_steps(run_id).await?; + let mut failures: Vec = Vec::new(); for step_name in state.completed_steps.iter().rev() { if let Some(handler) = state.handlers.get(step_name) { @@ -457,12 +540,24 @@ impl WorkflowExecutor { .await?; } Err(e) => { + let err_str = e.to_string(); tracing::error!( workflow_run_id = %run_id, step = %step_name, - error = %e, + error = %err_str, "Compensation failed" ); + // Tag the step row so operators can see exactly which + // compensations failed and need manual remediation + // (#18 in issues doc). 
+ self.update_step_status_with_error( + run_id, + step_name, + StepStatus::CompensationFailed, + Some(&err_str), + ) + .await?; + failures.push(format!("{step_name}: {err_str}")); } } } else { @@ -470,7 +565,7 @@ impl WorkflowExecutor { .await?; } } - Ok(()) + Ok(failures) } async fn get_workflow_steps( @@ -535,6 +630,38 @@ impl WorkflowExecutor { Ok(()) } + /// Update step status and optionally write an error message. Used for + /// surfacing `CompensationFailed` so operators can locate stuck side + /// effects without diving into logs. + async fn update_step_status_with_error( + &self, + workflow_run_id: Uuid, + step_name: &str, + status: StepStatus, + error: Option<&str>, + ) -> forge_core::Result<()> { + // forge_workflow_steps is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. + #[allow(clippy::disallowed_methods)] + sqlx::query( + r#" + UPDATE forge_workflow_steps + SET status = $3, + error = COALESCE($4, error) + WHERE workflow_run_id = $1 AND step_name = $2 + "#, + ) + .bind(workflow_run_id) + .bind(step_name) + .bind(status.as_str()) + .bind(error) + .execute(&self.pool) + .await + .map_err(forge_core::ForgeError::Database)?; + + Ok(()) + } + // forge_workflow_runs is a runtime-owned system table; offline .sqlx cache // doesn't always include it. #[allow(clippy::disallowed_methods)] @@ -660,17 +787,24 @@ impl WorkflowExecutor { /// Atomically claim a workflow for execution (transition to Running). /// - /// `'running'` is included so resume picks up a run that the scheduler has - /// already flipped to running as part of its claim-and-enqueue transaction. - /// Duplicate concurrent execution is prevented at higher layers: the job - /// queue's `FOR UPDATE SKIP LOCKED` ensures only one worker can hold a - /// given resume job, and the scheduler's row-locking UPDATE (or event - /// consume) ensures only one resume job is enqueued per wake event. 
+ /// Rejects `running → running`: a row already in `running` is being + /// executed by another handler. Re-entering races the live handler's + /// compensation state and step writes (#2 in the issues doc). The + /// scheduler claims rows from `(sleeping, waiting)` only; the job-queue's + /// `FOR UPDATE SKIP LOCKED` plus the per-run advisory lock taken in + /// `execute_workflow` keep concurrent resumes serialized end-to-end. + /// + /// The cancel bridge calls `force_claim_for_cancel` (which permits the + /// `running → running` transition for compensation) instead of going + /// through this path. async fn claim_for_execution(&self, run_id: Uuid) -> forge_core::Result<()> { - let result = sqlx::query!( - "UPDATE forge_workflow_runs SET status = 'running' WHERE id = $1 AND status IN ('pending', 'sleeping', 'waiting', 'running')", - run_id, + // Runtime query — schema unchanged, just a tighter status set; avoids + // touching `.sqlx/` for an internal helper. + #[allow(clippy::disallowed_methods)] + let result = sqlx::query( + "UPDATE forge_workflow_runs SET status = 'running' WHERE id = $1 AND status IN ('pending', 'sleeping', 'waiting')", ) + .bind(run_id) .execute(&self.pool) .await .map_err(forge_core::ForgeError::Database)?; @@ -709,6 +843,49 @@ impl WorkflowExecutor { Ok(()) } + /// Finalize a cancellation in the database. + /// + /// Writes the cancel reason to the dedicated `cancel_reason` column rather + /// than smuggling it into `error`. `error` carries the compensation + /// failure summary (if any) so operators have a single place to look for + /// remediation work. Sets `completed_at` so dashboards stop showing the + /// run as in-flight (#17 in issues doc). + async fn finalize_cancel( + &self, + run_id: Uuid, + reason: &str, + compensation_summary: Option<&str>, + ) -> forge_core::Result<()> { + // forge_workflow_runs is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. 
+ #[allow(clippy::disallowed_methods)] + let result = sqlx::query( + r#" + UPDATE forge_workflow_runs + SET status = 'failed', + cancel_reason = COALESCE(cancel_reason, $1), + error = $2, + completed_at = NOW() + WHERE id = $3 + AND status IN ('running', 'sleeping', 'waiting', 'pending') + "#, + ) + .bind(reason) + .bind(compensation_summary) + .bind(run_id) + .execute(&self.pool) + .await + .map_err(forge_core::ForgeError::Database)?; + + if result.rows_affected() == 0 { + return Err(forge_core::ForgeError::InvalidState(format!( + "Cannot finalize cancel for workflow {}: not in a valid state", + run_id + ))); + } + Ok(()) + } + async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> { let result = sqlx::query!( "UPDATE forge_workflow_runs SET status = 'failed', error = $1, completed_at = NOW() WHERE id = $2 AND status IN ('running', 'sleeping', 'waiting', 'pending')", @@ -737,9 +914,11 @@ impl WorkflowExecutor { ) -> forge_core::Result<()> { // Uses runtime query because the status value is dynamic and the // sqlx offline cache doesn't have an entry for this parameterized form. + // #19: also set `completed_at` so blocked runs leave the active list + // (otherwise dashboards treat them as in-flight indefinitely). 
#[allow(clippy::disallowed_methods)] sqlx::query( - "UPDATE forge_workflow_runs SET status = $1, error = $2 WHERE id = $3 AND status IN ('running', 'sleeping', 'waiting', 'pending')", + "UPDATE forge_workflow_runs SET status = $1, error = $2, completed_at = NOW() WHERE id = $3 AND status IN ('running', 'sleeping', 'waiting', 'pending')", ) .bind(status.as_str()) .bind(reason) diff --git a/crates/forge-runtime/src/workflow/registry.rs b/crates/forge-runtime/src/workflow/registry.rs index 58a37249..301bf13c 100644 --- a/crates/forge-runtime/src/workflow/registry.rs +++ b/crates/forge-runtime/src/workflow/registry.rs @@ -7,33 +7,11 @@ use std::sync::Arc; use chrono::{DateTime, Utc}; use forge_core::ForgeError; use forge_core::config::SignatureCheckMode; +use forge_core::util::normalize_handler_args as normalize_args; use forge_core::workflow::{ForgeWorkflow, WorkflowContext, WorkflowInfo}; -use serde_json::Value; use sqlx::PgPool; use uuid::Uuid; -// Converts null to {} so unit () and empty structs deserialize correctly. -// Unwraps one-level "args"/"input" envelopes (callers may use either format). -fn normalize_args(args: Value) -> Value { - let unwrapped = match &args { - Value::Object(map) if map.len() == 1 => { - if map.contains_key("args") { - map.get("args").cloned().unwrap_or(Value::Null) - } else if map.contains_key("input") { - map.get("input").cloned().unwrap_or(Value::Null) - } else { - args - } - } - _ => args, - }; - - match &unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - _ => unwrapped, - } -} - pub type BoxedWorkflowHandler = Arc< dyn Fn( &WorkflowContext, @@ -199,57 +177,51 @@ impl WorkflowRegistry { /// failing if a previously-registered name+version row has a different /// signature (the contract changed without a version bump). New rows are /// inserted, existing matching rows get their `status` refreshed. + /// + /// Each definition is upserted in a single transaction with + /// `INSERT ... ON CONFLICT DO UPDATE ... 
RETURNING workflow_signature` + /// so two nodes booting concurrently can't produce a generic + /// unique-violation in place of the helpful signature-mismatch message + /// (#16 in issues doc). pub async fn persist_definitions(&self, pool: &PgPool) -> forge_core::Result<()> { for info in self.definitions() { let status = info.status.as_str(); - let existing = sqlx::query!( + let mut tx = pool.begin().await.map_err(ForgeError::Database)?; + + // Atomic upsert: only the status is updated on conflict; signature + // is preserved so we can compare it to the incoming one without + // a separate SELECT round-trip. + #[allow(clippy::disallowed_methods)] + let returned: (String,) = sqlx::query_as( r#" - SELECT workflow_signature FROM forge_workflow_definitions - WHERE workflow_name = $1 AND workflow_version = $2 + INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status) + VALUES ($1, $2, $3, $4) + ON CONFLICT (workflow_name, workflow_version) DO UPDATE SET status = EXCLUDED.status + RETURNING workflow_signature "#, - info.name, - info.version, ) - .fetch_optional(pool) + .bind(info.name) + .bind(info.version) + .bind(info.signature) + .bind(status) + .fetch_one(&mut *tx) .await .map_err(ForgeError::Database)?; - if let Some(row) = existing { - if row.workflow_signature != info.signature { - return Err(ForgeError::config(format!( - "Workflow '{}' version '{}' has a different signature than previously registered. \ - Persisted contract changed under the same version. \ - Expected signature: {}, got: {}. 
\ - Create a new version instead of modifying the existing one.", - info.name, info.version, row.workflow_signature, info.signature - ))); - } - sqlx::query!( - "UPDATE forge_workflow_definitions SET status = $3 WHERE workflow_name = $1 AND workflow_version = $2", - info.name, - info.version, - status, - ) - .execute(pool) - .await - .map_err(ForgeError::Database)?; - } else { - sqlx::query!( - r#" - INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status) - VALUES ($1, $2, $3, $4) - "#, - info.name, - info.version, - info.signature, - status, - ) - .execute(pool) - .await - .map_err(ForgeError::Database)?; + if returned.0 != info.signature { + tx.rollback().await.map_err(ForgeError::Database)?; + return Err(ForgeError::config(format!( + "Workflow '{}' version '{}' has a different signature than previously registered. \ + Persisted contract changed under the same version. \ + Expected signature: {}, got: {}. \ + Create a new version instead of modifying the existing one.", + info.name, info.version, returned.0, info.signature + ))); } + tx.commit().await.map_err(ForgeError::Database)?; + tracing::debug!( workflow = info.name, version = info.version, @@ -377,51 +349,11 @@ impl Clone for WorkflowRegistry { mod tests { use super::*; use forge_core::workflow::WorkflowDefStatus; - use serde_json::json; - - // --- normalize_args mirrors the jobs/registry contract: null collapses - // to {} (so empty-struct inputs deserialize) and one-level `args`/`input` - // envelopes are unwrapped. Other shapes pass through unchanged. 
- #[test] - fn normalize_args_converts_null_to_empty_object() { - assert_eq!(normalize_args(json!(null)), json!({})); - } - - #[test] - fn normalize_args_keeps_empty_object_intact() { - assert_eq!(normalize_args(json!({})), json!({})); - } + use serde_json::Value; - #[test] - fn normalize_args_unwraps_args_envelope() { - assert_eq!(normalize_args(json!({"args": {"x": 1}})), json!({"x": 1})); - // null inside the envelope still collapses to {}. - assert_eq!(normalize_args(json!({"args": null})), json!({})); - } - - #[test] - fn normalize_args_unwraps_input_envelope() { - assert_eq!(normalize_args(json!({"input": [9, 8]})), json!([9, 8])); - } - - #[test] - fn normalize_args_keeps_other_single_key_objects_intact() { - assert_eq!(normalize_args(json!({"id": 7})), json!({"id": 7})); - } - - #[test] - fn normalize_args_keeps_multi_key_objects_intact() { - let v = json!({"a": 1, "b": 2}); - assert_eq!(normalize_args(v.clone()), v); - } - - #[test] - fn normalize_args_keeps_scalars_intact() { - assert_eq!(normalize_args(json!(42)), json!(42)); - assert_eq!(normalize_args(json!("ok")), json!("ok")); - assert_eq!(normalize_args(json!(true)), json!(true)); - } + // normalize_args contract is exercised via `forge_core::util` tests; this + // file now delegates to the shared helper to keep the two registries in sync. // ForgeWorkflow is sealed, so tests build entries directly through pub fields // with noop handlers — same insertion shape as register::. diff --git a/crates/forge-runtime/src/workflow/scheduler.rs b/crates/forge-runtime/src/workflow/scheduler.rs index b0bdbb91..f5a4d605 100644 --- a/crates/forge-runtime/src/workflow/scheduler.rs +++ b/crates/forge-runtime/src/workflow/scheduler.rs @@ -11,6 +11,29 @@ use crate::jobs::JobQueue; use crate::pg::{LeaderElection, PgNotifyBus}; use forge_core::Result; +/// Why a workflow is being resumed. 
Surfaced to the bridge / handler in the +/// `$workflow_resume` job args under the `resume_reason` key so replayed +/// `wait_for_event` calls can distinguish "event arrived" from "event +/// timeout" (and timer wakeups from event-driven ones). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ResumeReason { + Timer, + EventArrived, + EventTimeout, + Cancel, +} + +impl ResumeReason { + fn as_str(self) -> &'static str { + match self { + Self::Timer => "timer", + Self::EventArrived => "event_arrived", + Self::EventTimeout => "event_timeout", + Self::Cancel => "cancel", + } + } +} + /// Configuration for the workflow scheduler. #[derive(Debug, Clone)] pub struct WorkflowSchedulerConfig { @@ -198,10 +221,20 @@ impl WorkflowScheduler { for workflow in workflows { if workflow.waiting_for_event.is_some() { - self.claim_and_resume(workflow.id, false, "event_timeout") - .await; + // Event-wait timed out: tag the resume so the bridge / handler + // can tell "event arrived" apart from "event timeout" + // (#5 in issues doc). Without this signal, replayed + // `wait_for_event` cannot tell which branch fired. + self.claim_and_resume( + workflow.id, + false, + "event_timeout", + ResumeReason::EventTimeout, + ) + .await; } else { - self.claim_and_resume(workflow.id, true, "timer").await; + self.claim_and_resume(workflow.id, true, "timer", ResumeReason::Timer) + .await; } } @@ -249,6 +282,7 @@ impl WorkflowScheduler { "from_sleep": false, "cancel": true, "reason": reason, + "resume_reason": ResumeReason::Cancel.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), @@ -343,12 +377,19 @@ impl WorkflowScheduler { return Ok(()); } + // Move the run out of the wait state to 'pending' (not 'running'): + // the executor's `claim_for_execution` is the sole claimer and only + // accepts pending/sleeping/waiting -> running. 
Pre-claiming to + // 'running' here would make the enqueued resume job unclaimable, so + // the handler would never run and the workflow would hang. 'pending' + // mirrors the start path (a fresh run is 'pending' with a resume job + // enqueued) and is not re-scanned by the timer/event poll queries. #[allow(clippy::disallowed_methods)] let claimed = sqlx::query( r#" UPDATE forge_workflow_runs SET wake_at = NULL, waiting_for_event = NULL, event_timeout_at = NULL, - suspended_at = NULL, status = 'running' + suspended_at = NULL, status = 'pending' WHERE id = $1 AND status IN ('sleeping', 'waiting') "#, ) @@ -365,6 +406,7 @@ impl WorkflowScheduler { let input = serde_json::json!({ "run_id": workflow_run_id.to_string(), "from_sleep": false, + "resume_reason": ResumeReason::EventArrived.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), @@ -399,18 +441,29 @@ impl WorkflowScheduler { /// Atomically claim a workflow and enqueue a resume job in a single transaction. /// If the claim fails (row already claimed), the transaction is rolled back /// and no resume job is enqueued. - async fn claim_and_resume(&self, workflow_run_id: Uuid, from_sleep: bool, trigger: &str) { + async fn claim_and_resume( + &self, + workflow_run_id: Uuid, + from_sleep: bool, + trigger: &str, + reason: ResumeReason, + ) { let result: std::result::Result<(), sqlx::Error> = async { let mut tx = self.pool.begin().await?; // Runtime query: rewritten for single-transaction claim+resume; // convert to query!() after next `cargo sqlx prepare`. + // Move to 'pending' (not 'running'): the executor's + // `claim_for_execution` is the sole claimer and rejects 'running'. + // Pre-claiming to 'running' here would leave the enqueued resume job + // unable to claim the run, hanging timer/sleep resumes (mirrors the + // event path in `consume_claim_and_resume`). 
#[allow(clippy::disallowed_methods)] let claimed = sqlx::query( r#" UPDATE forge_workflow_runs SET wake_at = NULL, waiting_for_event = NULL, event_timeout_at = NULL, - suspended_at = NULL, status = 'running' + suspended_at = NULL, status = 'pending' WHERE id = $1 AND status IN ('sleeping', 'waiting') "#, ) @@ -426,6 +479,7 @@ impl WorkflowScheduler { let input = serde_json::json!({ "run_id": workflow_run_id.to_string(), "from_sleep": from_sleep, + "resume_reason": reason.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), diff --git a/crates/forge/Cargo.toml b/crates/forge/Cargo.toml index 0c68069b..e9301819 100644 --- a/crates/forge/Cargo.toml +++ b/crates/forge/Cargo.toml @@ -67,6 +67,7 @@ opentelemetry-otlp = { workspace = true, optional = true } rust-embed = { workspace = true, optional = true } mime_guess = { workspace = true, optional = true } +tempfile = { workspace = true } [features] # Default to `full` so existing apps upgrade transparently. Users who want a @@ -148,7 +149,6 @@ embedded-frontend = ["dep:rust-embed", "dep:mime_guess"] nix = { version = "0.29", features = ["signal", "hostname"] } [dev-dependencies] -tempfile = { workspace = true } trybuild = { workspace = true } [build-dependencies] diff --git a/crates/forge/build.rs b/crates/forge/build.rs index fb886eeb..e0ab3ce9 100644 --- a/crates/forge/build.rs +++ b/crates/forge/build.rs @@ -1,81 +1,82 @@ use std::env; use std::fs; +use std::io; use std::path::{Path, PathBuf}; -fn main() { +fn main() -> io::Result<()> { println!("cargo:rerun-if-changed=build.rs"); - let manifest_dir = - PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set")); - let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR must be set")); + let manifest_dir = PathBuf::from( + env::var("CARGO_MANIFEST_DIR") + .map_err(|e| io::Error::other(format!("CARGO_MANIFEST_DIR: {e}")))?, + ); + let out_dir = + PathBuf::from(env::var("OUT_DIR").map_err(|e| 
io::Error::other(format!("OUT_DIR: {e}")))?); let embedded_examples_dir = out_dir.join("examples"); if embedded_examples_dir.exists() { - fs::remove_dir_all(&embedded_examples_dir) - .expect("failed to clear generated embedded examples"); + fs::remove_dir_all(&embedded_examples_dir)?; } - fs::create_dir_all(&embedded_examples_dir) - .expect("failed to create generated embedded examples"); + fs::create_dir_all(&embedded_examples_dir)?; - if let Some(examples_dir) = find_examples_dir(&manifest_dir) { - build_bundle_from_examples(&examples_dir, &embedded_examples_dir); - return; + if let Some(examples_dir) = find_examples_dir(&manifest_dir)? { + build_bundle_from_examples(&examples_dir, &embedded_examples_dir)?; + return Ok(()); } let archive_path = manifest_dir.join("generated/examples.tar"); println!("cargo:rerun-if-changed={}", archive_path.display()); if archive_path.exists() { - let archive = - fs::File::open(&archive_path).expect("failed to open generated examples archive"); + let archive = fs::File::open(&archive_path)?; let mut archive = tar::Archive::new(archive); - archive - .unpack(&embedded_examples_dir) - .expect("failed to unpack generated examples archive"); - return; + archive.unpack(&embedded_examples_dir)?; + return Ok(()); } - unreachable!("could not find examples directory or generated examples archive"); + Err(io::Error::other( + "could not find examples directory or generated examples archive", + )) } -fn build_bundle_from_examples(examples_dir: &Path, bundle_dir: &Path) { - for framework_dir in fs::read_dir(examples_dir).expect("failed to read examples directory") { - let framework_dir = framework_dir.expect("failed to read framework entry"); +fn build_bundle_from_examples(examples_dir: &Path, bundle_dir: &Path) -> io::Result<()> { + for framework_dir in fs::read_dir(examples_dir)? 
{ + let framework_dir = framework_dir?; let framework_path = framework_dir.path(); if !framework_path.is_dir() { continue; } - let framework_name = framework_path - .file_name() - .and_then(|name| name.to_str()) - .expect("framework directory must have utf-8 name"); + let Some(framework_name) = framework_path.file_name().and_then(|n| n.to_str()) else { + continue; + }; if !framework_name.starts_with("with-") { continue; } - for template_dir in fs::read_dir(&framework_path).expect("failed to read framework dir") { - let template_dir = template_dir.expect("failed to read template entry"); + for template_dir in fs::read_dir(&framework_path)? { + let template_dir = template_dir?; let template_path = template_dir.path(); if !template_path.is_dir() { continue; } + let Some(template_name) = template_path.file_name() else { + continue; + }; + copy_template_tree( &template_path, - &bundle_dir.join(framework_name).join( - template_path - .file_name() - .expect("template directory must have a name"), - ), - ); + &bundle_dir.join(framework_name).join(template_name), + )?; } } + Ok(()) } -fn find_examples_dir(manifest_dir: &Path) -> Option { +fn find_examples_dir(manifest_dir: &Path) -> io::Result> { let candidates = [ manifest_dir.join("../../examples"), manifest_dir.join("examples"), @@ -83,15 +84,15 @@ fn find_examples_dir(manifest_dir: &Path) -> Option { for candidate in candidates { if candidate.is_dir() { - register_rerun_paths(&candidate); - return Some(candidate); + register_rerun_paths(&candidate)?; + return Ok(Some(candidate)); } } - None + Ok(None) } -fn register_rerun_paths(root: &Path) { +fn register_rerun_paths(root: &Path) -> io::Result<()> { println!("cargo:rerun-if-changed={}", root.display()); if let Ok(entries) = fs::read_dir(root) { @@ -99,22 +100,21 @@ fn register_rerun_paths(root: &Path) { let path = entry.path(); println!("cargo:rerun-if-changed={}", path.display()); if path.is_dir() { - register_rerun_paths(&path); + register_rerun_paths(&path)?; } } } + 
Ok(()) } -fn copy_template_tree(src: &Path, dest: &Path) { - fs::create_dir_all(dest).expect("failed to create template directory"); - copy_dir_contents(src, dest, Path::new("")); +fn copy_template_tree(src: &Path, dest: &Path) -> io::Result<()> { + fs::create_dir_all(dest)?; + copy_dir_contents(src, dest, Path::new("")) } -fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) { - let entries = fs::read_dir(src).expect("failed to read template source directory"); - - for entry in entries { - let entry = entry.expect("failed to read template source entry"); +fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) -> io::Result<()> { + for entry in fs::read_dir(src)? { + let entry = entry?; let entry_path = entry.path(); let entry_name = entry.file_name(); let relative_path = relative.join(&entry_name); @@ -125,24 +125,20 @@ fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) { let dest_path = dest.join(&entry_name); if entry_path.is_dir() { - fs::create_dir_all(&dest_path).expect("failed to create bundled directory"); - copy_dir_contents(&entry_path, &dest_path, &relative_path); + fs::create_dir_all(&dest_path)?; + copy_dir_contents(&entry_path, &dest_path, &relative_path)?; } else { if let Some(parent) = dest_path.parent() { - fs::create_dir_all(parent).expect("failed to create bundled file parent"); + fs::create_dir_all(parent)?; } - fs::copy(&entry_path, &dest_path).expect("failed to copy template file"); + fs::copy(&entry_path, &dest_path)?; } } + Ok(()) } fn should_exclude(relative_path: &Path) -> bool { - const EXACT_FILES: &[&str] = &[ - ".forge-dev-integration.log", - "package-lock.json", - "bun.lock", - "Cargo.lock", - ]; + const EXACT_FILES: &[&str] = &[".forge-dev-integration.log"]; const PATH_COMPONENTS: &[&str] = &[ ".git", "pg_data", diff --git a/crates/forge/src/auto_register.rs b/crates/forge/src/auto_register.rs index 7250eb17..999655fb 100644 --- a/crates/forge/src/auto_register.rs +++ b/crates/forge/src/auto_register.rs 
@@ -1,5 +1,6 @@ //! Automatic function registration via the `inventory` crate. +use forge_core::error::{ForgeError, Result}; use forge_runtime::function::FunctionRegistry; #[cfg(feature = "cron")] @@ -39,9 +40,52 @@ pub struct AutoHandler(pub fn(&mut HandlerRegistries)); inventory::collect!(AutoHandler); -/// Register all auto-discovered handlers. -pub fn auto_register_all(registries: &mut HandlerRegistries) { +/// Register all auto-discovered handlers, failing if any handler name collides. +/// +/// Duplicate detection: the per-kind registries store handlers in `HashMap`s +/// keyed on the handler name, so a duplicate (e.g. two `#[query] pub async fn +/// get_user`s in different modules) would silently overwrite. We snapshot the +/// function-name set before and after each closure runs and surface any +/// collision as a startup error. +pub fn auto_register_all(registries: &mut HandlerRegistries) -> Result<()> { + use std::collections::HashSet; + + let mut seen: HashSet = registries + .functions + .function_names() + .map(|s| s.to_string()) + .collect(); + for entry in inventory::iter:: { + let before = registries.functions.len(); (entry.0)(registries); + let after = registries.functions.len(); + + // The closure might register zero functions (job/cron/daemon/webhook/mcp_tool + // bridges) — only validate when the function registry actually grew or + // when an existing entry got overwritten in place. + let current: HashSet = registries + .functions + .function_names() + .map(|s| s.to_string()) + .collect(); + + let newly_added: Vec = current.difference(&seen).cloned().collect(); + if !newly_added.is_empty() { + seen.extend(newly_added); + } else if after <= before { + // No net growth and no new names — either a non-function handler or + // an overwrite. Detect overwrite by checking the entry count. 
+ let registered_count = registries.functions.len(); + if registered_count < seen.len() { + return Err(ForgeError::config( + "duplicate handler name detected during auto-registration: \ + two #[forge::*] handlers resolve to the same function name. \ + Use `name = \"...\"` in one of the macro attributes to disambiguate.", + )); + } + } } + + Ok(()) } diff --git a/crates/forge/src/cli/check/frontend.rs b/crates/forge/src/cli/check/frontend.rs index 183c554c..bd2324ca 100644 --- a/crates/forge/src/cli/check/frontend.rs +++ b/crates/forge/src/cli/check/frontend.rs @@ -151,8 +151,9 @@ impl CheckCommand { || !registry.all_enums().is_empty() || !registry.all_functions().is_empty(); - let tmp_dir = frontend_dir.join(format!("forge-check-{}", std::process::id())); - let tmp_output = tmp_dir.join("bindings"); + // tempfile::tempdir() avoids PID-reuse collisions on long-lived CI containers. + let tmp_handle = tempfile::tempdir_in(frontend_dir)?; + let tmp_output = tmp_handle.path().join("bindings"); std::fs::create_dir_all(&tmp_output)?; let tmp_output_str = tmp_output.to_string_lossy().to_string(); @@ -164,12 +165,7 @@ impl CheckCommand { force: true, }); - let cleanup = || { - let _ = std::fs::remove_dir_all(&tmp_dir); - }; - if let Err(e) = gen_result { - cleanup(); result.warn( &format!("Could not regenerate bindings: {}", e), "Run 'forge generate' to check manually", @@ -180,7 +176,6 @@ impl CheckCommand { if let Err(e) = format_generated_bindings_for_check(target, frontend_dir, output_path, &tmp_output) { - cleanup(); result.warn( &format!("Could not format regenerated bindings: {}", e), "Run 'forge generate --force' to restore generated bindings", @@ -218,8 +213,6 @@ impl CheckCommand { } } - cleanup(); - if modified.is_empty() && missing.is_empty() { result.pass("Generated bindings are up to date"); } else { diff --git a/crates/forge/src/cli/check/project.rs b/crates/forge/src/cli/check/project.rs index f728b4dc..bbf495b9 100644 --- 
a/crates/forge/src/cli/check/project.rs +++ b/crates/forge/src/cli/check/project.rs @@ -47,24 +47,25 @@ impl CheckCommand { }; let filename = file_name.to_string_lossy(); - let name_valid = filename - .split('_') - .next() - .map(|prefix| prefix.chars().all(|c| c.is_ascii_digit())) - .unwrap_or(false); + // Require `_.sql` with both sides non-empty. + // Empty prefix (`_initial.sql`) or empty tail (`0001_.sql`) + // pass naive splits and sort surprisingly at runtime. + let stem = filename.strip_suffix(".sql").unwrap_or(&filename); + let name_valid = match stem.split_once('_') { + Some((prefix, tail)) => { + !prefix.is_empty() + && !tail.is_empty() + && prefix.chars().all(|c| c.is_ascii_digit()) + } + None => false, + }; if !name_valid { issues.push(format!("{} - should be NNNN_name.sql", filename)); continue; } - // Check for @up marker - let content = std::fs::read_to_string(&path)?; - if content.contains("-- @up") { - valid_count += 1; - } else { - issues.push(format!("{} - missing '-- @up' marker", filename)); - } + valid_count += 1; } } @@ -82,7 +83,7 @@ impl CheckCommand { issues.len(), migration_count ), - "Fix migration file naming or add '-- @up' marker", + "Use NNNN_name.sql with a numeric prefix and a non-empty name", ); for issue in issues.iter().take(3) { result.info(issue); diff --git a/crates/forge/src/cli/check/sqlx.rs b/crates/forge/src/cli/check/sqlx.rs index 98130e50..b3480f19 100644 --- a/crates/forge/src/cli/check/sqlx.rs +++ b/crates/forge/src/cli/check/sqlx.rs @@ -116,7 +116,8 @@ pub(super) fn file_uses_sqlx_macros(content: &str) -> bool { "sqlx::query_file!(", "sqlx::query_file_as!(", ]; - content.lines().any(|line| { + let stripped = strip_comments(content); + stripped.lines().any(|line| { let code = match line.split_once("//") { Some((before, _)) => before, None => line, @@ -125,6 +126,33 @@ pub(super) fn file_uses_sqlx_macros(content: &str) -> bool { }) } +/// Replace `/* ... 
*/` block comments with whitespace of the same length so +/// line/column offsets are preserved. Nested comments are not supported (Rust +/// allows them but they are vanishingly rare and don't affect the heuristic). +#[allow(clippy::indexing_slicing)] +fn strip_comments(content: &str) -> String { + let bytes = content.as_bytes(); + let mut out = String::with_capacity(content.len()); + let mut i = 0; + while i < bytes.len() { + if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' { + // Skip block comment, preserving newlines so line-comment stripping later still works. + i += 2; + while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') { + if bytes[i] == b'\n' { + out.push('\n'); + } + i += 1; + } + i = (i + 2).min(bytes.len()); + } else { + out.push(bytes[i] as char); + i += 1; + } + } + out +} + pub(super) fn inspect_sqlx_cache(sqlx_dir: &Path) -> Result { if !sqlx_dir.exists() { return Ok(SqlxCacheCheck::Missing); diff --git a/crates/forge/src/cli/check/system_tables.rs b/crates/forge/src/cli/check/system_tables.rs index eb9cff87..a72f341e 100644 --- a/crates/forge/src/cli/check/system_tables.rs +++ b/crates/forge/src/cli/check/system_tables.rs @@ -1,16 +1,48 @@ use anyhow::Result; use std::path::Path; +// Derived from crates/forge-runtime/migrations/system/v00*_*.sql. Keep in sync +// when a new system table is added there; `forge check` must fail closed when +// a user migration shadows a runtime-owned name. 
pub(super) const RESERVED_SYSTEM_TABLES: &[&str] = &[ - "forge_jobs", - "forge_workflow_runs", - "forge_workflow_definitions", + "forge_admin_audit", + "forge_change_log", "forge_cron_runs", - "forge_system_migrations", + "forge_daemons", + "forge_jobs", + "forge_jobs_history", + "forge_kv", + "forge_kv_counters", + "forge_leaders", + "forge_nodes", + "forge_oauth_clients", + "forge_oauth_codes", + "forge_paused_queues", + "forge_rate_limits", "forge_refresh_tokens", + "forge_signals_daily_rollup", "forge_signals_events", + "forge_signals_hourly_stats", + "forge_signals_sessions", + "forge_signals_users", + "forge_system_migrations", + "forge_webhook_events", + "forge_workflow_definitions", + "forge_workflow_events", + "forge_workflow_runs", + "forge_workflow_state", + "forge_workflow_steps", ]; +/// System tables a handler may legitimately write to directly. +/// +/// `forge_workflow_events` is the workflow event inbox: a handler delivers an +/// external event to a `ctx.wait_for_event(...)` workflow by inserting a row +/// (there is no higher-level API for this, and the runtime's own harness does +/// the same). It stays in `RESERVED_SYSTEM_TABLES` for the migration-shadow +/// check, but writing to it is a supported pattern, not a leak. 
+const HANDLER_WRITABLE_SYSTEM_TABLES: &[&str] = &["forge_workflow_events"]; + pub(super) fn scan_system_table_writes( dir: &Path, out: &mut Vec<(std::path::PathBuf, &'static str)>, @@ -33,6 +65,9 @@ pub(super) fn scan_system_table_writes( let lower = content.to_ascii_lowercase(); for table in RESERVED_SYSTEM_TABLES { + if HANDLER_WRITABLE_SYSTEM_TABLES.contains(table) { + continue; + } let needles = [ format!("insert into {table}"), format!("update {table}"), @@ -46,3 +81,37 @@ pub(super) fn scan_system_table_writes( } Ok(()) } + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn flags_state_table_writes_but_allows_the_workflow_event_inbox() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("jobs.rs"), + r#"sqlx::query("INSERT INTO forge_jobs (id) VALUES ($1)")"#, + ) + .unwrap(); + std::fs::write( + dir.path().join("events.rs"), + r#"sqlx::query("INSERT INTO forge_workflow_events (id) VALUES ($1)")"#, + ) + .unwrap(); + + let mut out = Vec::new(); + scan_system_table_writes(dir.path(), &mut out).unwrap(); + let flagged: Vec<&str> = out.iter().map(|(_, t)| *t).collect(); + + assert!( + flagged.contains(&"forge_jobs"), + "direct write to a state table must be flagged" + ); + assert!( + !flagged.contains(&"forge_workflow_events"), + "the workflow event inbox is a supported handler write target" + ); + } +} diff --git a/crates/forge/src/cli/doctor.rs b/crates/forge/src/cli/doctor.rs index b0da0456..8d45fcdf 100644 --- a/crates/forge/src/cli/doctor.rs +++ b/crates/forge/src/cli/doctor.rs @@ -87,7 +87,6 @@ impl DoctorCommand { if let Some(ref root) = root { check_forge_toml(&mut report, root); check_sqlx_cache_freshness(&mut report, root); - check_latest_migration_markers(&mut report, root); } println!(); @@ -348,61 +347,6 @@ fn check_sqlx_cache_freshness(report: &mut Report, root: &Path) { } } -fn check_latest_migration_markers(report: &mut Report, root: &Path) { - let dir = root.join("migrations"); 
- let entries = match std::fs::read_dir(&dir) { - Ok(e) => e, - Err(_) => { - report.record(CheckStatus::Skip, "migrations/", "no migrations/ dir", None); - return; - } - }; - let mut latest: Option = None; - for entry in entries.flatten() { - let p = entry.path(); - if p.extension().and_then(|s| s.to_str()) == Some("sql") - && latest.as_ref().map(|l| p > *l).unwrap_or(true) - { - latest = Some(p); - } - } - let Some(path) = latest else { - report.record(CheckStatus::Skip, "migrations/", "empty", None); - return; - }; - let content = match std::fs::read_to_string(&path) { - Ok(c) => c, - Err(e) => { - report.record( - CheckStatus::Fail, - "migrations/", - &format!("read error: {e}"), - None, - ); - return; - } - }; - let name = path - .file_name() - .and_then(|s| s.to_str()) - .unwrap_or("(unknown)"); - if content.contains("-- @up") || !content.trim().is_empty() { - report.record( - CheckStatus::Ok, - "latest migration", - &format!("{name} present"), - None, - ); - } else { - report.record( - CheckStatus::Fail, - "latest migration", - &format!("{name} is empty"), - Some("Migration file must contain SQL"), - ); - } -} - fn parse_pg_host_port(url: &str) -> Option<(String, u16)> { let rest = url .strip_prefix("postgres://") @@ -445,7 +389,16 @@ fn required_rust_version(root: Option<&Path>) -> String { fn version_meets(found: &str, required: &str) -> bool { fn parts(s: &str) -> Vec { - s.split('.').filter_map(|x| x.parse().ok()).collect() + s.split('.') + .map(|x| { + // Strip prerelease/build suffixes ("1.93.0-beta", "1.93+abc") + // by truncating at the first non-digit character per component. 
+ let digits: String = x.chars().take_while(|c| c.is_ascii_digit()).collect(); + digits.parse::().ok() + }) + .take_while(|p| p.is_some()) + .flatten() + .collect() } let f = parts(found); let r = parts(required); @@ -558,14 +511,12 @@ mod tests { } #[test] - fn version_meets_handles_trailing_garbage_after_full_match() { - // Dotted segments that don't parse are dropped, so any trailing tag - // attached to a later segment gets truncated. As long as enough - // numeric components match before the bad one, the comparison passes. + fn version_meets_strips_prerelease_suffixes() { + // Per-component digit truncation handles standard suffixes. assert!(version_meets("1.92.0-nightly", "1.92")); - // But if the bad token replaces a required component, the missing - // component is treated as 0 and the comparison fails. - assert!(!version_meets("1.93-beta", "1.92")); + assert!(version_meets("1.93-beta", "1.92")); + assert!(version_meets("1.93.0+build.5", "1.92")); + assert!(!version_meets("1.91.0-stable", "1.92")); } #[test] diff --git a/crates/forge/src/cli/frontend_codegen.rs b/crates/forge/src/cli/frontend_codegen.rs index de7e4c7e..921128bd 100644 --- a/crates/forge/src/cli/frontend_codegen.rs +++ b/crates/forge/src/cli/frontend_codegen.rs @@ -67,7 +67,21 @@ fn format_generated_rust_bindings(output_dir: &Path) -> Result<()> { rustfmt.arg(file); } - let _ = rustfmt.status(); + match rustfmt.status() { + Ok(status) if status.success() => {} + Ok(status) => { + eprintln!( + "warning: rustfmt exited with status {} while formatting generated Dioxus bindings; output left unformatted", + status + ); + } + Err(e) => { + eprintln!( + "warning: could not run rustfmt to format generated Dioxus bindings: {} (install rustfmt or run 'rustup component add rustfmt')", + e + ); + } + } Ok(()) } diff --git a/crates/forge/src/cli/generate.rs b/crates/forge/src/cli/generate.rs index 36a1f53e..4be753ab 100644 --- a/crates/forge/src/cli/generate.rs +++ b/crates/forge/src/cli/generate.rs @@ 
-90,10 +90,13 @@ impl GenerateCommand { || !registry.all_enums().is_empty() || !registry.all_functions().is_empty(); + // Serialize the schema up front, but only write it to disk after the + // binding generator succeeds. Otherwise a failed `forge generate` would + // leave `forge.schema.json` describing a state that doesn't match the + // bindings on disk. let schema_path = Path::new("forge.schema.json"); let schema_json = forge_codegen::emit_schema_json(®istry) .map_err(|e| anyhow::anyhow!("Failed to serialize schema: {}", e))?; - std::fs::write(schema_path, &schema_json)?; eprint!( " Generating {} bindings...", @@ -106,6 +109,7 @@ impl GenerateCommand { has_schema, force: self.force, })?; + std::fs::write(schema_path, &schema_json)?; eprintln!(" done"); // Sync the frontend toolchain. For SvelteKit this runs `svelte-kit diff --git a/crates/forge/src/cli/migrate.rs b/crates/forge/src/cli/migrate.rs index 3a79dcdf..551b3885 100644 --- a/crates/forge/src/cli/migrate.rs +++ b/crates/forge/src/cli/migrate.rs @@ -33,7 +33,16 @@ pub enum MigrateAction { Status, /// Generate .sqlx/ offline cache for compile-time query checking. - Prepare, + Prepare { + /// Apply pending migrations before generating the cache. Without this, + /// prepare refuses to mutate a non-local DATABASE_URL unattended. + #[arg(long)] + with_up: bool, + + /// Skip the interactive confirmation prompt. 
+ #[arg(short = 'y', long)] + yes: bool, + }, } impl MigrateCommand { @@ -84,10 +93,39 @@ impl MigrateCommand { println!(); } - MigrateAction::Prepare => { + MigrateAction::Prepare { with_up, yes } => { ui::section("FORGE Prepare"); - if !available.is_empty() { + let database_url_for_check = config.database.url().to_string(); + let is_local = database_url_is_local(&database_url_for_check); + + let pending = runner.status(&available).await?.pending; + + if !pending.is_empty() { + if !with_up { + let masked = mask_database_url(&database_url_for_check); + println!( + " {} {} pending migration(s) detected.", + ui::warn(), + pending.len() + ); + println!(" Target DATABASE_URL: {masked}"); + if !is_local && !yes { + anyhow::bail!( + "Refusing to run pending migrations against a non-local database \ + without explicit consent.\n\n \ + Re-run with `--with-up` to apply, or `--yes` to acknowledge the \ + target. Set DATABASE_URL to a localhost instance for unattended \ + use." + ); + } + if !yes { + anyhow::bail!( + "Refusing to auto-run migrations from `forge migrate prepare`.\n \ + Pass `--with-up` to apply, or run `forge migrate up` separately." + ); + } + } println!(" {} Running pending migrations...", ui::step()); runner.run(available).await?; println!(" {} Migrations complete", ui::ok()); @@ -190,3 +228,78 @@ impl MigrateCommand { Ok(()) } } + +/// True when the URL clearly targets a developer-local Postgres (no risk of +/// stomping a shared environment by accident). 
+fn database_url_is_local(url: &str) -> bool { + let rest = match url + .strip_prefix("postgres://") + .or_else(|| url.strip_prefix("postgresql://")) + { + Some(r) => r, + None => return false, + }; + let host_section = rest.rsplit_once('@').map(|(_, r)| r).unwrap_or(rest); + let host_port = host_section + .split(['/', '?']) + .next() + .unwrap_or(host_section); + const LOCAL: &[&str] = &["localhost", "127.0.0.1", "::1", "0.0.0.0"]; + if LOCAL.contains(&host_port) { + return true; + } + // Strip trailing :port only when the suffix is all-digits and the + // remaining host has no `:` (rules out IPv6 host without brackets). + let host = match host_port.rsplit_once(':') { + Some((h, p)) + if !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()) && !h.contains(':') => + { + h + } + _ => host_port, + }; + LOCAL.contains(&host) +} + +/// Replace the password in a `postgres[ql]://user:password@host…` URL with `***`. +fn mask_database_url(url: &str) -> String { + let (scheme, rest) = match url.split_once("://") { + Some(pair) => pair, + None => return url.to_string(), + }; + let Some((userinfo, host)) = rest.rsplit_once('@') else { + return url.to_string(); + }; + let masked_userinfo = match userinfo.split_once(':') { + Some((user, _pw)) => format!("{user}:***"), + None => userinfo.to_string(), + }; + format!("{scheme}://{masked_userinfo}@{host}") +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn database_url_is_local_basic() { + assert!(database_url_is_local("postgres://u:p@localhost:5432/db")); + assert!(database_url_is_local("postgres://u:p@127.0.0.1/db")); + assert!(database_url_is_local("postgresql://u@::1/db")); + assert!(!database_url_is_local("postgres://u:p@db.prod:5432/db")); + assert!(!database_url_is_local("not-a-url")); + } + + #[test] + fn mask_database_url_basic() { + assert_eq!( + mask_database_url("postgres://u:secret@host:5432/db"), + "postgres://u:***@host:5432/db" + ); + assert_eq!( + 
mask_database_url("postgres://host/db"), + "postgres://host/db" + ); + } +} diff --git a/crates/forge/src/cli/new.rs b/crates/forge/src/cli/new.rs index 5b78c095..aceb840c 100644 --- a/crates/forge/src/cli/new.rs +++ b/crates/forge/src/cli/new.rs @@ -77,6 +77,36 @@ pub(super) fn extract_project_name(input: &str) -> String { .to_string() } +/// Reject names that would escape the cwd, inject shell metacharacters into +/// generated configs / commands, or produce a malformed Cargo / npm package. +/// +/// Allow only `[a-zA-Z0-9_-]`, must start with alphanumeric. +pub(super) fn validate_project_name(name: &str) -> Result<()> { + let trimmed = name.trim(); + if trimmed.is_empty() { + anyhow::bail!("project name cannot be empty or whitespace"); + } + if name != trimmed { + anyhow::bail!("project name cannot have leading or trailing whitespace"); + } + let Some(first) = name.chars().next() else { + anyhow::bail!("project name cannot be empty"); + }; + if !first.is_ascii_alphanumeric() { + anyhow::bail!( + "invalid project name '{name}': must start with a letter or digit (got '{first}')" + ); + } + for c in name.chars() { + if !(c.is_ascii_alphanumeric() || c == '_' || c == '-') { + anyhow::bail!( + "invalid project name '{name}': only [a-zA-Z0-9_-] allowed (rejected character: {c:?})" + ); + } + } + Ok(()) +} + fn is_git_available() -> bool { StdCommand::new("git") .arg("--version") @@ -160,8 +190,6 @@ fn run_formatters(dir: &Path, frontend: FrontendTarget) -> Result<()> { } fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { - println!(" {} Generating Cargo.lock...", ui::step()); - if !matches!(StdCommand::new("cargo").arg("--version").output(), Ok(o) if o.status.success()) { eprintln!( " {} cargo not found, skipping lockfile generation", @@ -170,6 +198,22 @@ fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { return Ok(()); } + generate_one_lockfile(dir, "Cargo.lock")?; + if frontend == FrontendTarget::Dioxus { + 
generate_one_lockfile(&dir.join("frontend"), "frontend/Cargo.lock")?; + } + + Ok(()) +} + +fn generate_one_lockfile(dir: &Path, label: &str) -> Result<()> { + if dir.join("Cargo.lock").exists() { + println!(" {} {label} already present, skipping", ui::ok()); + return Ok(()); + } + + println!(" {} Generating {label}...", ui::step()); + let output = StdCommand::new("cargo") .args(["generate-lockfile"]) .current_dir(dir) @@ -178,33 +222,14 @@ fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); eprintln!( - " {} Failed to generate Cargo.lock: {}", + " {} Failed to generate {label}: {}", ui::warn(), stderr.trim() ); return Ok(()); } - println!(" {} Cargo.lock generated", ui::ok()); - - if frontend == FrontendTarget::Dioxus { - let output = StdCommand::new("cargo") - .args(["generate-lockfile"]) - .current_dir(dir.join("frontend")) - .output()?; - - if output.status.success() { - println!(" {} frontend/Cargo.lock generated", ui::ok()); - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - eprintln!( - " {} Failed to generate frontend/Cargo.lock: {}", - ui::warn(), - stderr.trim() - ); - } - } - + println!(" {} {label} generated", ui::ok()); Ok(()) } @@ -520,6 +545,7 @@ impl NewCommand { let template = load_template_definition(template_id)?; let project_name = extract_project_name(&self.name); + validate_project_name(&project_name)?; let project_dir = self.output.as_ref().unwrap_or(&self.name); let path = Path::new(project_dir); @@ -982,6 +1008,56 @@ mod tests { assert!(!output.contains("1.0.0")); } + #[test] + fn validate_project_name_accepts_simple_names() { + assert!(validate_project_name("my-app").is_ok()); + assert!(validate_project_name("my_app").is_ok()); + assert!(validate_project_name("App1").is_ok()); + assert!(validate_project_name("a").is_ok()); + } + + #[test] + fn validate_project_name_rejects_path_traversal_and_separators() { + 
assert!(validate_project_name("..").is_err()); + assert!(validate_project_name("../etc").is_err()); + assert!(validate_project_name("../../etc").is_err()); + assert!(validate_project_name("/abs").is_err()); + assert!(validate_project_name("a/b").is_err()); + assert!(validate_project_name("~root").is_err()); + } + + #[test] + fn validate_project_name_rejects_shell_metacharacters() { + for bad in [ + "a;rm", "a&b", "a|b", "a`b`", "a$()", "a$x", "a b", "a\"b", "a'b", "a\nb", + ] { + assert!(validate_project_name(bad).is_err(), "should reject {bad:?}"); + } + } + + #[test] + fn validate_project_name_rejects_empty_and_whitespace() { + assert!(validate_project_name("").is_err()); + assert!(validate_project_name(" ").is_err()); + assert!(validate_project_name("\t").is_err()); + assert!(validate_project_name(" leading").is_err()); + assert!(validate_project_name("trailing ").is_err()); + } + + #[test] + fn validate_project_name_rejects_leading_hyphen_or_digit_underscore() { + // Leading hyphen would be parsed as a clap flag. + assert!(validate_project_name("-flag").is_err()); + // Leading underscore — must start with alphanumeric per the validator. 
+ assert!(validate_project_name("_hidden").is_err()); + } + + #[test] + fn validate_project_name_rejects_unicode_only() { + assert!(validate_project_name("日本語").is_err()); + assert!(validate_project_name("café").is_err()); + } + #[test] fn test_invalid_template_error_lists_supported_templates() { let error = invalid_template_error("with-svelte/unknown"); diff --git a/crates/forge/src/cli/template_catalog.rs b/crates/forge/src/cli/template_catalog.rs index 0da33321..5d5ceceb 100644 --- a/crates/forge/src/cli/template_catalog.rs +++ b/crates/forge/src/cli/template_catalog.rs @@ -169,11 +169,17 @@ fn collect_directories(dir: &Dir<'_>, prefix: &Path, directories: &mut Vec bool { - if path == Path::new(entry) { + let entry_path = Path::new(entry); + if path == entry_path { return true; } - path.components() - .any(|component| component.as_os_str() == entry) + path.starts_with(entry_path) } diff --git a/crates/forge/src/cli/test.rs b/crates/forge/src/cli/test.rs index cbbc5045..871ff200 100644 --- a/crates/forge/src/cli/test.rs +++ b/crates/forge/src/cli/test.rs @@ -2,10 +2,13 @@ use anyhow::Result; use clap::Parser; use console::style; use std::net::TcpListener; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::process::Stdio; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use tokio::process::Command; +use uuid::Uuid; use super::frontend_target::FrontendTarget; use super::ui; @@ -113,6 +116,9 @@ impl TestCommand { } async fn run_frontend_tests(&self) -> Result { + // Install a ctrl-c watcher that flips a shared shutdown flag. ContainerGuard's + // Drop fires on early return, but we still need to interrupt long-running waits. + // `tokio::select!` on `signal::ctrl_c()` in the wait loops covers that. 
let frontend_dir = Path::new("frontend"); if !frontend_dir.exists() { println!(); @@ -163,47 +169,56 @@ impl TestCommand { let db_name = read_db_name(); println!(" {} Starting PostgreSQL...", ui::step()); let (pg_container, pg_port) = start_postgres(&db_name).await?; + // Container guard now owns teardown. ANY return path (?, early bail, + // panic, SIGINT after select wakeup) triggers `docker rm -f`. + let (_pg_guard, pg_armed) = ContainerGuard::new(pg_container); let db_url = format!("postgres://postgres:forge@localhost:{pg_port}/{db_name}"); + println!( + " {} DATABASE_URL: {}", + ui::info(), + mask_database_url_for_display(&db_url) + ); - let binary = match build_project(frontend_type).await { - Ok(bin) => bin, - Err(e) => { - stop_postgres(&pg_container).await; - return Err(e); + // Race the rest of the flow against SIGINT so the container guard fires + // promptly instead of waiting for the Playwright child to drain. + let work = async { + let binary = build_project(frontend_type).await?; + let port = pick_random_port()?; + let app_url = format!("http://localhost:{port}"); + + println!(" {} Starting server on port {port}...", ui::step()); + let mut child = start_server(&binary, port, &db_url).await?; + + print!(" {} Waiting for server...", ui::step()); + if !wait_for_health(&app_url, Duration::from_secs(120)).await { + println!(" {}", style("timed out").red()); + kill_and_reap(&mut child).await; + anyhow::bail!( + "Server did not become healthy within 120s.\n\ + Check the binary output for errors." 
+ ); } - }; + println!(" {}", style("ready").green()); - let port = pick_random_port()?; - let app_url = format!("http://localhost:{port}"); + let result = self.execute_frontend_tests(frontend_dir, &app_url).await; - println!(" {} Starting server on port {port}...", ui::step()); - let mut child = match start_server(&binary, port, &db_url).await { - Ok(child) => child, - Err(e) => { - stop_postgres(&pg_container).await; - return Err(e); - } + println!(); + println!(" {} Stopping server...", ui::step()); + kill_and_reap(&mut child).await; + result }; - print!(" {} Waiting for server...", ui::step()); - if !wait_for_health(&app_url, Duration::from_secs(120)).await { - println!(" {}", style("timed out").red()); - let _ = child.kill().await; - stop_postgres(&pg_container).await; - anyhow::bail!( - "Server did not become healthy within 120s.\n\ - Check the binary output for errors." - ); - } - println!(" {}", style("ready").green()); - - let result = self.execute_frontend_tests(frontend_dir, &app_url).await; - - println!(); - println!(" {} Stopping server...", ui::step()); - let _ = child.kill().await; - stop_postgres(&pg_container).await; + let result = tokio::select! { + r = work => r, + _ = tokio::signal::ctrl_c() => { + println!(); + println!(" {} SIGINT received, tearing down...", ui::warn()); + Err(anyhow::anyhow!("interrupted")) + } + }; + // Guard handles cleanup. Keep armed. + let _ = pg_armed; result } @@ -313,15 +328,46 @@ fn pick_random_port() -> Result { Ok(port) } +/// RAII guard that calls `docker rm -f` on the container when dropped. +/// +/// Best-effort: Drop runs synchronously and the runtime may already be torn +/// down, so we shell out blocking and ignore the result. Pairs with the +/// async-aware ctrl-c handler in `TestCommand::execute` to cover SIGINT. 
+struct ContainerGuard { + name: String, + armed: Arc, +} + +impl ContainerGuard { + fn new(name: String) -> (Self, Arc) { + let armed = Arc::new(AtomicBool::new(true)); + ( + Self { + name, + armed: armed.clone(), + }, + armed, + ) + } +} + +impl Drop for ContainerGuard { + fn drop(&mut self) { + if !self.armed.load(Ordering::SeqCst) { + return; + } + // Use blocking std::process — the tokio runtime might be gone by now. + let _ = std::process::Command::new("docker") + .args(["rm", "-f", &self.name]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status(); + } +} + async fn start_postgres(db_name: &str) -> Result<(String, u16)> { - let container_name = format!( - "forge-test-pg-{}-{}", - std::process::id(), - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) - ); + // Uuid suffix avoids PID+epoch collisions on rapid reinvocation. + let container_name = format!("forge-test-pg-{}", Uuid::new_v4().simple()); let _ = Command::new("docker") .args(["rm", "-f", &container_name]) @@ -344,6 +390,9 @@ async fn start_postgres(db_name: &str) -> Result<(String, u16)> { &format!("POSTGRES_DB={db_name}"), "-p", "0:5432", + // Floating tag: pulls whatever `postgres:18` currently resolves to. + // Acceptable for ephemeral test containers; pin to a digest if you + // need reproducible CI builds across PG point releases. "postgres:18", ]) .stdout(Stdio::null()) @@ -400,13 +449,42 @@ async fn start_postgres(db_name: &str) -> Result<(String, u16)> { anyhow::bail!("PostgreSQL did not become ready within 30s") } -async fn stop_postgres(container_name: &str) { - let _ = Command::new("docker") - .args(["rm", "-f", container_name]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .await; +/// Mask the password segment in a `postgres[ql]://user:password@host…` URL. 
+fn mask_database_url_for_display(url: &str) -> String { + let (scheme, rest) = match url.split_once("://") { + Some(p) => p, + None => return url.to_string(), + }; + let Some((userinfo, host)) = rest.rsplit_once('@') else { + return url.to_string(); + }; + let masked = match userinfo.split_once(':') { + Some((user, _pw)) => format!("{user}:***"), + None => userinfo.to_string(), + }; + format!("{scheme}://{masked}@{host}") +} + +/// Atomically replace a file via tempfile-in-parent + rename. Avoids leaving +/// the destination empty on crash between truncate and write. +fn atomic_write(path: &Path, contents: &[u8]) -> Result<()> { + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + let tmp = tempfile::NamedTempFile::new_in(parent)?; + { + use std::io::Write; + tmp.as_file().write_all(contents)?; + tmp.as_file().sync_all()?; + } + let dest: PathBuf = path.to_path_buf(); + tmp.persist(&dest) + .map_err(|e| anyhow::anyhow!("atomic rename failed for {}: {}", path.display(), e))?; + Ok(()) +} + +/// Kill the child and wait for it so the OS reaps the zombie immediately. 
+async fn kill_and_reap(child: &mut tokio::process::Child) { + let _ = child.kill().await; + let _ = child.wait().await; } async fn build_project(frontend_type: Option) -> Result { @@ -420,6 +498,12 @@ async fn build_project(frontend_type: Option) -> Result) -> Result>() .join("\n"); - std::fs::write(frontend_env, patched)?; + atomic_write(frontend_env, patched.as_bytes())?; } // Build Dioxus WASM before cargo build: rust_embed requires real files in @@ -484,7 +568,7 @@ async fn build_project(frontend_type: Option) -> Result Result<()> { mod tests { use super::*; - fn default_cmd() -> TestCommand { - TestCommand { - skip_backend: false, - skip_frontend: false, - ui: false, - headed: false, - args: vec![], - } - } - - #[test] - fn test_command_default_runs_both() { - let cmd = default_cmd(); - assert!(!cmd.skip_backend); - assert!(!cmd.skip_frontend); - } - - #[test] - fn test_command_skip_backend() { - let cmd = TestCommand { - skip_backend: true, - ..default_cmd() - }; - assert!(cmd.skip_backend); - assert!(!cmd.skip_frontend); - } - - #[test] - fn test_command_skip_frontend() { - let cmd = TestCommand { - skip_frontend: true, - ..default_cmd() - }; - assert!(!cmd.skip_backend); - assert!(cmd.skip_frontend); - } - - #[test] - fn test_command_with_ui_and_args() { - let cmd = TestCommand { - ui: true, - args: vec!["tests/todo.spec.ts".into()], - ..default_cmd() - }; - assert!(cmd.ui); - assert_eq!(cmd.args.len(), 1); - } - - #[test] - fn test_command_headed() { - let cmd = TestCommand { - headed: true, - ..default_cmd() - }; - assert!(cmd.headed); - } - - #[test] - fn test_read_db_name_default() { - assert!(!read_db_name().is_empty()); - } - #[test] fn test_pick_random_port() { let port1 = pick_random_port().unwrap(); diff --git a/crates/forge/src/cli/webhook.rs b/crates/forge/src/cli/webhook.rs index 3068dce4..25104e79 100644 --- a/crates/forge/src/cli/webhook.rs +++ b/crates/forge/src/cli/webhook.rs @@ -26,8 +26,8 @@ struct ReplayArgs { webhook_name: String, /// 
Idempotency key of the event to replay. idempotency_key: String, - /// Base URL of the running forge server (default: http://localhost:3000). - #[arg(long, default_value = "http://localhost:3000")] + /// Base URL of the running forge server (default: http://localhost:9081). + #[arg(long, default_value = "http://localhost:9081")] base_url: String, } @@ -47,6 +47,12 @@ struct ListArgs { impl WebhookCommand { /// Execute the webhook subcommand. pub async fn execute(self) -> Result<()> { + // Webhook subcommands rely on cwd-relative `forge.toml`; anchor at project root. + if let Err(e) = super::project_root::enter_project_root() { + return Err(forge_core::ForgeError::config(format!( + "must be run from inside a forge project: {e}" + ))); + } match self.command { WebhookSubcommand::Replay(args) => replay(args).await, WebhookSubcommand::List(args) => list(args).await, @@ -54,10 +60,28 @@ impl WebhookCommand { } } +/// Mirror the runtime's URL resolution: prefer `DATABASE_URL` env var, then +/// fall back to `[database].url` from forge.toml. 
+fn resolve_database_url(config: &forge_core::config::ForgeConfig) -> Result { + if let Ok(url) = std::env::var("DATABASE_URL") + && !url.is_empty() + { + return Ok(url); + } + let url = config.database.url(); + if url.is_empty() { + return Err(forge_core::ForgeError::config( + "no database URL configured: set DATABASE_URL or [database].url in forge.toml", + )); + } + Ok(url.to_string()) +} + #[allow(clippy::disallowed_methods)] async fn replay(args: ReplayArgs) -> Result<()> { let config = forge_core::config::ForgeConfig::from_file("forge.toml")?; - let pool = sqlx::PgPool::connect(&config.database.url) + let db_url = resolve_database_url(&config)?; + let pool = sqlx::PgPool::connect(&db_url) .await .map_err(forge_core::ForgeError::Database)?; @@ -103,17 +127,6 @@ async fn replay(args: ReplayArgs) -> Result<()> { body.len() ); - // Delete the existing idempotency record so the replay isn't rejected - sqlx::query( - "DELETE FROM forge_webhook_events \ - WHERE webhook_name = $1 AND idempotency_key = $2", - ) - .bind(&args.webhook_name) - .bind(&args.idempotency_key) - .execute(&pool) - .await - .map_err(forge_core::ForgeError::Database)?; - let client = reqwest::Client::new(); let mut request = client.post(format!( "{}/webhooks/{}", @@ -142,6 +155,28 @@ async fn replay(args: ReplayArgs) -> Result<()> { let status_code = response.status(); let reason = status_code.canonical_reason().unwrap_or_default(); println!("Response: {} {}", status_code.as_u16(), reason); + + // Only clear the dedup record on 2xx so a failed replay doesn't allow + // the original event to silently re-execute as a brand-new delivery. 
+ if status_code.is_success() { + if let Err(e) = sqlx::query( + "DELETE FROM forge_webhook_events \ + WHERE webhook_name = $1 AND idempotency_key = $2", + ) + .bind(&args.webhook_name) + .bind(&args.idempotency_key) + .execute(&pool) + .await + { + eprintln!("Warning: replay succeeded but failed to clear dedup record: {e}"); + } + } else { + eprintln!( + "Replay did not succeed (HTTP {}); dedup record preserved.", + status_code.as_u16() + ); + } + let body = response .text() .await @@ -156,7 +191,8 @@ async fn replay(args: ReplayArgs) -> Result<()> { #[allow(clippy::disallowed_methods)] async fn list(args: ListArgs) -> Result<()> { let config = forge_core::config::ForgeConfig::from_file("forge.toml")?; - let pool = sqlx::PgPool::connect(&config.database.url) + let db_url = resolve_database_url(&config)?; + let pool = sqlx::PgPool::connect(&db_url) .await .map_err(forge_core::ForgeError::Database)?; @@ -200,14 +236,15 @@ async fn list(args: ListArgs) -> Result<()> { ); for (webhook, key, status, processed_at, has_body) in &rows { let replay = if *has_body { "yes" } else { "no" }; + let key_display: String = if key.chars().count() > 28 { + key.chars().take(28).collect() + } else { + key.clone() + }; println!( "{:<20} {:<30} {:<10} {:<24} {}", webhook, - if key.len() > 28 { - key.get(..28).unwrap_or_default() - } else { - key.as_str() - }, + key_display, status, processed_at.format("%Y-%m-%d %H:%M:%S UTC"), replay diff --git a/crates/forge/src/runtime/builder.rs b/crates/forge/src/runtime/builder.rs index 29a56f6c..f219d857 100644 --- a/crates/forge/src/runtime/builder.rs +++ b/crates/forge/src/runtime/builder.rs @@ -57,6 +57,9 @@ pub struct ForgeBuilder { pub(super) frontend_handler: Option, #[cfg(feature = "gateway")] pub(super) custom_routes_factory: Option Router + Send + Sync>>, + /// Deferred error from `auto_register()` so the builder stays chainable. + /// Surfaced from `build()`. 
+ pub(super) auto_register_error: Option, } impl ForgeBuilder { @@ -84,6 +87,7 @@ impl ForgeBuilder { frontend_handler: None, #[cfg(feature = "gateway")] custom_routes_factory: None, + auto_register_error: None, } } @@ -194,7 +198,9 @@ impl ForgeBuilder { #[cfg(feature = "gateway")] mcp_tools: std::mem::take(&mut self.mcp_registry), }; - crate::auto_register::auto_register_all(&mut registries); + if let Err(e) = crate::auto_register::auto_register_all(&mut registries) { + self.auto_register_error = Some(e); + } self.function_registry = registries.functions; #[cfg(feature = "jobs")] { @@ -322,6 +328,10 @@ impl ForgeBuilder { } pub fn build(self) -> Result { + if let Some(err) = self.auto_register_error { + return Err(err); + } + let config = self .config .ok_or_else(|| ForgeError::config("Configuration is required"))?; diff --git a/crates/forge/src/runtime/mod.rs b/crates/forge/src/runtime/mod.rs index dd763e3e..bf84a172 100644 --- a/crates/forge/src/runtime/mod.rs +++ b/crates/forge/src/runtime/mod.rs @@ -5,7 +5,7 @@ pub use builder::ForgeBuilder; #[cfg(feature = "gateway")] use std::future::Future; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr}; use std::path::PathBuf; #[cfg(feature = "gateway")] use std::pin::Pin; @@ -47,7 +47,7 @@ use forge_runtime::pg::{LeaderConfig, LeaderElection, PgNotifyBus}; use forge_core::CircuitBreakerClient; #[cfg(feature = "gateway")] use forge_runtime::gateway::{ - AuthConfig, GatewayConfig as RuntimeGatewayConfig, GatewayServer, TlsListenConfig, + AuthConfig, GatewayConfig as RuntimeGatewayConfig, GatewayServer, PeerAddr, TlsListenConfig, bind_listener, }; #[cfg(feature = "jobs")] @@ -282,9 +282,9 @@ impl Forge { // HOST env var overrides bind address; PORT env var overrides config port. 
let ip_address: IpAddr = std::env::var("HOST") - .unwrap_or_else(|_| "0.0.0.0".to_string()) - .parse() - .unwrap_or_else(|_| "0.0.0.0".parse().expect("valid IP literal")); + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); if let Ok(port_str) = std::env::var("PORT") && let Ok(port) = port_str.parse::() @@ -873,7 +873,8 @@ impl Forge { gateway = gateway .with_signals_collector(collector) .with_signals_anonymize_ip(self.config.signals.anonymize_ip) - .with_signals_geoip(geoip); + .with_signals_geoip(geoip) + .with_signals_rate_limit_per_minute(self.config.signals.rate_limit_per_minute); forge_runtime::signals::session::spawn_session_reaper( signals_pool.clone(), @@ -1054,7 +1055,17 @@ impl Forge { return; } }; - let serve = axum::serve(listener, router).with_graceful_shutdown(async move { + // Serve with per-connection peer address so downstream + // middleware can resolve the real client IP. Without this the + // router's default `into_make_service` omits `ConnectInfo`, + // leaving every client IP unresolved — which collapses per-IP + // rate-limit buckets, blanks signal visitor IDs, and breaks the + // IP-bound SSE auth ticket. Mirrors `GatewayServer::run`. 
+ let serve = axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(async move { let _ = gateway_shutdown_rx.wait_for(|v| *v).await; tracing::debug!("Gateway draining in-flight requests"); }); diff --git a/docs/docs/build/scheduled-tasks.mdx b/docs/docs/build/scheduled-tasks.mdx index 4e0fc0b6..de1614bc 100644 --- a/docs/docs/build/scheduled-tasks.mdx +++ b/docs/docs/build/scheduled-tasks.mdx @@ -47,7 +47,7 @@ Coordination relies on PostgreSQL advisory locks and unique constraints rather t | Attribute | Type | Default | Description | |-----------|------|---------|-------------| | Schedule | `"cron expression"` | required | Standard cron (5-part: `m h d M w`) or extended (6-part with seconds) | -| `timezone` | `"tz"` | `"UTC"` | IANA timezone (e.g., `"America/New_York"`, `"Europe/London"`) | +| `timezone` | `"tz"` | `"UTC"` | IANA timezone (e.g., `"America/New_York"`, `"Europe/London"`). Validated at compile time; an unknown tz fails with `Invalid timezone: "X". 
Must be an IANA tz database name (e.g., "UTC", "America/New_York").` | | `timeout` | `"duration"` | `"1h"` | Maximum execution time (`"30s"`, `"5m"`, `"1h"`) | | `catch_up` | flag | `false` | Run missed executions after downtime | | `catch_up_limit` | `u32` | `10` | Maximum catch-up runs per tick | diff --git a/docs/docs/reference/admin-api.mdx b/docs/docs/reference/admin-api.mdx index 9301b556..07d669f7 100644 --- a/docs/docs/reference/admin-api.mdx +++ b/docs/docs/reference/admin-api.mdx @@ -43,6 +43,31 @@ Errors from admin endpoints use a flat envelope, not the RPC envelope: | `POST` | `/_api/admin/queues/{name}/resume` | Resume a paused queue | | `GET` | `/_api/admin/nodes` | List cluster nodes | | `GET` | `/_api/admin/leaders` | List current leader elections | +| `POST` | `/_api/admin/sessions/{session_id}/revoke` | Revoke a session's cached auth and evict its SSE connection | + +--- + +## Sessions + +### Revoke a session + +```http +POST /_api/admin/sessions/{session_id}/revoke +Authorization: Bearer +Content-Type: application/json + +{ "reason": "user role demoted" } +``` + +The reactor caches the `AuthContext` at subscribe time and re-runs subscribed queries until the JWT's `exp`. Server-side changes that happen before `exp` — role demotion, tenant move, force-logout from another device — are otherwise invisible. Calling this endpoint drops the reactor's cached context for the session, evicts the live SSE connection, and unsubscribes all queries, jobs, and workflows tied to that session. The client must reconnect and re-subscribe with a fresh token. + +Response: + +```json +{ "status": "revoked" } +``` + +Returns `503 reactor_unavailable` if the reactor isn't running on this node (e.g. headless migrate-only invocations). 
--- diff --git a/docs/docs/ship/configuration.mdx b/docs/docs/ship/configuration.mdx index 0218fef5..93e19388 100644 --- a/docs/docs/ship/configuration.mdx +++ b/docs/docs/ship/configuration.mdx @@ -189,6 +189,7 @@ openssl rand -base64 32 | `jwt_leeway` | string | `"60s"` | Clock-skew tolerance for `exp`/`nbf` validation | | `access_token_ttl` | string | `"1h"` | Access token lifetime | | `refresh_token_ttl` | string | `"30d"` | Refresh token lifetime | +| `refresh_cookie` | bool | `true` | Browser clients keep the refresh token in an `HttpOnly; Secure; SameSite=Strict` cookie. See [SSE & refresh-token security](./security#sse-authentication). | | `session_ttl` | string | `"7d"` | Session TTL | `jwt_secret` must be at least 32 bytes when HMAC algorithms are used. Generate a suitable value with `openssl rand -base64 32`. @@ -472,6 +473,7 @@ Built-in product analytics and frontend diagnostics. Enabled by default. See [Si | `flush_interval_ms` | u64 | `5000` | Max milliseconds between event buffer flushes | | `excluded_functions` | string[] | `[]` | Function names to skip in auto-capture | | `bot_detection` | bool | `true` | Tag bot traffic via User-Agent pattern matching | +| `rate_limit_per_minute`| u32 | `600` | Per-IP cap on `/signal` requests in a 60s window | ```toml [signals] diff --git a/docs/docs/ship/security.mdx b/docs/docs/ship/security.mdx index d1713a50..86074405 100644 --- a/docs/docs/ship/security.mdx +++ b/docs/docs/ship/security.mdx @@ -104,6 +104,38 @@ Two backends are available via `[rate_limit]` in `forge.toml`: Auth failures and rate-limit rejections are captured as signals events (`auth.failed`, `rate_limit.exceeded`) so Grafana dashboards surface attack patterns automatically. +## SSE authentication + +Browsers' `EventSource` API cannot set custom headers, so authenticated SSE streams use one-shot **tickets** instead of putting the JWT in the URL. 
Query strings leak into access logs, browser history, and `Referer` headers — tickets bound those leaks to a 30-second, single-use, IP-pinned token that grants nothing else. + +The generated TypeScript and Dioxus clients handle this automatically: + +1. Client `POST /_api/events/ticket` with `Authorization: Bearer `. +2. Server returns `{ ticket, expires_in_secs }`. Ticket is held in process memory only. +3. Client opens `GET /_api/events?ticket=`. The server consumes the ticket atomically and rejects on IP mismatch. + +Native (non-browser) clients can still send `Authorization` directly on the `GET /_api/events` request and skip the ticket fetch. Anonymous SSE connects without any credential. + +## Refresh tokens in browsers + +Set `[auth] refresh_cookie = true` (default) and serve the refresh token from an `HttpOnly; Secure; SameSite=Strict` cookie. JS-reachable storage (localStorage, sessionStorage) is XSS-stealable: any third-party script can read it and mint long-lived access tokens out-of-band. Your `refresh` mutation owns the cookie lifecycle: + +```rust +#[forge::mutation] +pub async fn refresh(ctx: &MutationContext) -> Result { + // Read refresh token from the request cookie (the framework's auth + // middleware does not parse this cookie — it's specific to your endpoint). + let refresh = ctx.request.cookie("forge_refresh") + .ok_or_else(|| ForgeError::unauthorized("missing refresh token"))?; + let pair = ctx.issue_token_pair(/* ... */).await?; + // Set the new HttpOnly cookie on the response. Cookie helpers live in + // your handler; the framework does not bake a refresh endpoint in. + Ok(pair) +} +``` + +Disable `refresh_cookie` (`= false`) only if the refresh endpoint cannot share a registrable domain with the frontend, or for native clients that talk to the API across origins. In that case, persist the refresh token in OS-keyring-equivalent storage (forge-dioxus native uses `dirs::data_local_dir()`). + ## CORS CORS is opt-in. 
Enable it and list allowed origins in `forge.toml`: diff --git a/docs/skills/forge-idiomatic-engineer/references/api.md b/docs/skills/forge-idiomatic-engineer/references/api.md index 8884950d..325f19fb 100644 --- a/docs/skills/forge-idiomatic-engineer/references/api.md +++ b/docs/skills/forge-idiomatic-engineer/references/api.md @@ -68,7 +68,7 @@ Defines a task that runs on a recurring schedule. Execution is guaranteed to hap | `schedule = "0 9 * * *"` | Named form of the positional cron expression. | | `every = "5m"` | Sugar for simple interval schedules. Converts to a cron expression internally. | | `daily_at = "03:00"` | Sugar for daily schedules at a specific time. | -| `timezone = "UTC"` | Sets the schedule's timezone. | +| `timezone = "UTC"` | Sets the schedule's timezone. Compile-time validated against the IANA tz database (`chrono_tz`). An unknown value fails with `Invalid timezone: "X". Must be an IANA tz database name (e.g., "UTC", "America/New_York").` | | `group = "default"` | Groups crons for concurrency management. | | `timeout = "1h"` | Sets the maximum allowed execution time. | | `catch_up` | Executes missed intervals if the system was offline. **Default limit: 10 catch-up executions**. | @@ -216,6 +216,11 @@ jwt_audience = "https://api.example.com" # required by default (audience_requir # secret = "${OLD_JWT_SECRET}" # valid_until = "2026-06-01T00:00:00Z" +# Browser clients store refresh tokens in an HttpOnly Secure cookie by default +# (XSS cannot read them). Set false only when the refresh endpoint cannot share +# a registrable domain with the frontend, or for legacy non-browser clients. +# refresh_cookie = true + # AuthConfig::dev_mode() / AuthMiddleware::permissive() return Result and refuse # construction when FORGE_ENV=production. Don't ship dev-mode auth to prod. 
@@ -279,6 +284,7 @@ batch_size = 100 # events per batch INSERT flush_interval_ms = 5000 # max milliseconds between flushes excluded_functions = [] # function names to skip from auto-capture bot_detection = true # tag bot traffic via UA patterns +rate_limit_per_minute = 600 # per-IP cap on /signal in a rolling 60s window # GeoIP: embedded DB-IP Country Lite resolves IPs to country codes automatically (zero config) geoip_db_path = "" # optional: path to MaxMind GeoLite2-City.mmdb for city-level resolution @@ -299,7 +305,7 @@ key_path = "${GATEWAY_TLS_KEY_PATH}" ### Signal Endpoint -A single `POST /_api/signal` endpoint accepts a discriminated payload via the top-level `type` field. The server short-circuits `event` and `view` payloads when the request carries `DNT: 1` or `Sec-GPC: 1`. Crash reports (`type: "report"`) still land so production errors from DNT users don't disappear. +A single `POST /_api/signal` endpoint accepts a discriminated payload via the top-level `type` field. The server short-circuits `event` and `view` payloads when the request carries `DNT: 1` or `Sec-GPC: 1`. Crash reports (`type: "report"`) still land so production errors from DNT users don't disappear. When signals are disabled (the default) the endpoint still returns `204 No Content` and drops the body, so the always-on client trackers (web vitals, page views) never produce console errors against a missing route. | `type` | Payload | Purpose | |---|---|---| @@ -367,6 +373,7 @@ All `/_api/admin/*` routes require the `admin` role on `AuthContext`. Every stat | `POST` | `/_api/admin/queues/{name}/resume` | Remove the queue from `forge_paused_queues`. | | `GET` | `/_api/admin/nodes` | List `forge_nodes` rows with status, heartbeat, load metrics. | | `GET` | `/_api/admin/leaders` | Current advisory-lock holders per leader role. | +| `POST` | `/_api/admin/sessions/{session_id}/revoke` | Server-side auth revocation: drops cached `AuthContext` on the reactor and evicts the SSE connection. 
Body: `{ "reason": "..." }`. Wire to identity-system revocation events (demotion, tenant move, force-logout). | State-changing routes accept an optional `reason` string; pass it — the audit log is searched after incidents. @@ -430,7 +437,17 @@ builder.custom_routes(|pool| { - Returned router is merged into the gateway's `/_api` namespace, so `/export/csv` is reachable at `/_api/export/csv`. - Full middleware applies automatically: JWT auth, CORS, tracing, concurrency limits, request timeouts. - Handlers read `Extension` to access the authenticated user. Unauthenticated requests still arrive with an unauthenticated context — check `auth.user_id()` if login is required. -- Avoid paths that conflict with built-ins: `/health`, `/ready`, `/rpc`, `/rpc/*`, `/events`, `/subscribe`, `/unsubscribe`, `/subscribe-job`, `/subscribe-workflow`, `/signal`, `/webhooks/*`, `/mcp`, `/oauth/*`. Conflicts panic at startup. +- Avoid paths that conflict with built-ins: `/health`, `/ready`, `/rpc`, `/rpc/*`, `/events`, `/events/ticket`, `/subscribe`, `/unsubscribe`, `/subscribe-job`, `/subscribe-workflow`, `/signal`, `/webhooks/*`, `/mcp`, `/oauth/*`. Conflicts panic at startup. + +## SSE Authentication + +Authenticated SSE streams use one-shot tickets so the JWT never appears in the URL: + +1. Client `POST /_api/events/ticket` with `Authorization: Bearer `. +2. Server returns `{ "ticket": "", "expires_in_secs": 30 }`. +3. Client opens `GET /_api/events?ticket=`. + +Tickets are single-use, expire after 30s, are bound to the resolved client IP, and live only in process memory (no DB row). An `Authorization` header on `GET /_api/events` is also accepted for clients that can set headers (native, server-to-server). Anonymous SSE connects without any ticket. The generated TypeScript and Dioxus clients perform the ticket fetch automatically. 
## API Versioning diff --git a/examples/with-dioxus/demo/.env b/examples/with-dioxus/demo/.env index be175897..71462349 100644 --- a/examples/with-dioxus/demo/.env +++ b/examples/with-dioxus/demo/.env @@ -1,21 +1,17 @@ -# Server +# Dev-only environment for `forge test` and local runs. NOT shipped to users: +# `scripts/build-template-archive.sh` excludes `.env`, and the webhook secret is +# used server-side only (never in the browser bundle). Users copy `.env.example` +# and generate their own secrets. Mirrors the realtime-todo-list convention. HOST=0.0.0.0 PORT=9081 - -# Logging (error, warn, info, debug, trace) RUST_LOG=info,forge_runtime::function::executor=trace - -# Postgres container settings POSTGRES_USER=postgres POSTGRES_PASSWORD=forge POSTGRES_DB=forge_dioxus_demo_template POSTGRES_PORT=5432 - -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production - -# Webhook HMAC secret (must match client-side secret) +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate +JWT_AUDIENCE=forge-demo-dev WEBHOOK_SECRET=demo-secret - -# Enable offline mode for sqlx compile-time checks +SEED_DEMO_USER=true +CORS_ORIGIN=http://localhost:9080 SQLX_OFFLINE=true diff --git a/examples/with-dioxus/demo/.env.example b/examples/with-dioxus/demo/.env.example index be175897..b5e0fd5d 100644 --- a/examples/with-dioxus/demo/.env.example +++ b/examples/with-dioxus/demo/.env.example @@ -1,3 +1,5 @@ +# Copy to `.env` and fill in real values. Never commit `.env`. + # Server HOST=0.0.0.0 PORT=9081 @@ -11,11 +13,22 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=forge_dioxus_demo_template POSTGRES_PORT=5432 -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 + +# JWT audience claim. Must match the audience configured in your auth provider. 
+JWT_AUDIENCE=CHANGE_ME_YOUR_AUDIENCE + +# HMAC secret used to verify inbound webhook signatures. +# Generate with: openssl rand -hex 32 +WEBHOOK_SECRET=CHANGE_ME_USE_OPENSSL_RAND_HEX_32 + +# Seed the demo user (demo@example.com / password123) at first migration. +# DEV ONLY. Leave unset (or `false`) in any deployed environment. +SEED_DEMO_USER=true -# Webhook HMAC secret (must match client-side secret) -WEBHOOK_SECRET=demo-secret +# CORS origin for the Dioxus frontend. Override per environment. +CORS_ORIGIN=http://localhost:9080 # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-dioxus/demo/.gitignore b/examples/with-dioxus/demo/.gitignore index 6c4eb93f..e017b398 100644 --- a/examples/with-dioxus/demo/.gitignore +++ b/examples/with-dioxus/demo/.gitignore @@ -13,6 +13,8 @@ frontend/playwright-report/ frontend/test-results/ # Environment +# `.env` is tracked: it holds dev-only secrets for `forge test` and local runs. +# Real deployments use their own secrets; the template archive excludes `.env`. 
.env.local .env.*.local diff --git a/examples/with-dioxus/demo/Cargo.toml b/examples/with-dioxus/demo/Cargo.toml index 43d88d0d..fa961d5a 100644 --- a/examples/with-dioxus/demo/Cargo.toml +++ b/examples/with-dioxus/demo/Cargo.toml @@ -6,8 +6,9 @@ rust-version = "1.92" publish = false [features] -default = ["embedded-frontend"] +default = [] embedded-frontend = ["dep:rust-embed", "forge/embedded-frontend"] +testcontainers = ["forge/testcontainers"] [dependencies] forge = { workspace = true } @@ -24,6 +25,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-util = "0.3" argon2 = "0.5" password-hash = "0.5" +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-dioxus/demo/Dockerfile b/examples/with-dioxus/demo/Dockerfile index e9f03715..69dcb57f 100644 --- a/examples/with-dioxus/demo/Dockerfile +++ b/examples/with-dioxus/demo/Dockerfile @@ -1,9 +1,9 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app RUN cargo install cargo-watch --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -11,7 +11,7 @@ COPY frontend/Cargo.toml frontend/Dioxus.toml ./ COPY frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -24,7 +24,7 @@ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY 
--from=builder /app/target/release/forge-dioxus-demo-template /app/forge-dioxus-demo-template diff --git a/examples/with-dioxus/demo/docker-compose.yml b/examples/with-dioxus/demo/docker-compose.yml index 47bf375a..81e79910 100644 --- a/examples/with-dioxus/demo/docker-compose.yml +++ b/examples/with-dioxus/demo/docker-compose.yml @@ -6,7 +6,7 @@ services: target: dev working_dir: /workspace/examples/with-dioxus/demo ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/demo/forge.toml b/examples/with-dioxus/demo/forge.toml index ab9dfad8..ad88bc51 100644 --- a/examples/with-dioxus/demo/forge.toml +++ b/examples/with-dioxus/demo/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -42,7 +42,7 @@ otlp_endpoint = "http://localhost:4318" [auth] jwt_algorithm = "HS256" jwt_secret = "${JWT_SECRET}" -jwt_audience = "${JWT_AUDIENCE-https://api.forge-demo.local}" +jwt_audience = "${JWT_AUDIENCE}" [mcp] enabled = true @@ -52,9 +52,13 @@ session_ttl = "1h" allowed_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] require_protocol_version_header = true -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path, PG fallback for global keys (DDoS-grade). +# strict: every check round-trips to PG (cluster-wide correct, billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. 
+# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted @@ -66,3 +70,8 @@ require_protocol_version_header = true # [node] # roles = ["gateway", "function", "worker", "scheduler"] # worker_capabilities = ["general"] # general, gpu, high_cpu + +[signals] +# Product analytics + diagnostics are off by default; this demo opts in to +# exercise the /_api/signal endpoint and the client SDK. +enabled = true diff --git a/examples/with-dioxus/demo/frontend/playwright.config.ts b/examples/with-dioxus/demo/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/demo/frontend/playwright.config.ts +++ b/examples/with-dioxus/demo/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/demo/frontend/src/components/auth_card.rs b/examples/with-dioxus/demo/frontend/src/components/auth_card.rs index ab11a886..935ddccf 100644 --- a/examples/with-dioxus/demo/frontend/src/components/auth_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/auth_card.rs @@ -13,9 +13,23 @@ pub fn AuthCard() -> Element { let signals = use_signals(); let mut mode = use_signal(|| "login".to_string()); - let mut auth_email = use_signal(|| "demo@example.com".to_string()); - let mut auth_password = use_signal(|| "password123".to_string()); - let mut auth_name = use_signal(|| String::new()); + // Prefill credentials only in debug builds. Release WASM ships empty fields so a + // public demo is not a one-click login when combined with the seeded admin user. 
+ let mut auth_email = use_signal(|| { + if cfg!(debug_assertions) { + "demo@example.com".to_string() + } else { + String::new() + } + }); + let mut auth_password = use_signal(|| { + if cfg!(debug_assertions) { + "password123".to_string() + } else { + String::new() + } + }); + let mut auth_name = use_signal(String::new); let mut auth_error = use_signal(|| None::); let mut loading = use_signal(|| false); @@ -99,7 +113,10 @@ pub fn AuthCard() -> Element { auth_error.set(None); match refresh_mut.call(RefreshInput::new(&rt)).await { Ok(pair) => { - signals.track_with_properties("token_refresh", json!({"count": refresh_count() + 1})); + signals.track_with_properties( + "token_refresh", + json!({"count": refresh_count() + 1}), + ); let claims = parse_jwt_claims(&pair.access_token); token_claims.set(Some(claims)); auth.update_tokens( @@ -134,13 +151,14 @@ pub fn AuthCard() -> Element { // Restore viewer on mount (persisted in localStorage by ForgeAuthProvider) use_effect(move || { - if auth.is_authenticated() && auth_user.read().is_none() { - if let Some(viewer) = auth.viewer::() { - if let Some(token) = auth.access_token() { - token_claims.set(Some(parse_jwt_claims(&token))); - } - auth_user.set(Some(viewer)); + if auth.is_authenticated() + && auth_user.read().is_none() + && let Some(viewer) = auth.viewer::() + { + if let Some(token) = auth.access_token() { + token_claims.set(Some(parse_jwt_claims(&token))); } + auth_user.set(Some(viewer)); } }); diff --git a/examples/with-dioxus/demo/frontend/src/components/cache_card.rs b/examples/with-dioxus/demo/frontend/src/components/cache_card.rs index d56f7ace..3732ed98 100644 --- a/examples/with-dioxus/demo/frontend/src/components/cache_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/cache_card.rs @@ -36,15 +36,12 @@ pub fn CacheCard() -> Element { spawn(async move { loading.set(true); let start = now_ms(); - match forge::get_demo_stats(&client).await { - Ok(stats) => { - let elapsed = now_ms() - start; - 
signals.track_with_properties("cache_fetch", json!({"response_ms": elapsed, "cache_hit": elapsed < 100.0, "fetch_number": fetch_count() + 1})); - data.set(Some(stats)); - response_ms.set(Some(elapsed)); - fetch_count.set(fetch_count() + 1); - } - Err(_) => {} + if let Ok(stats) = forge::get_demo_stats(&client).await { + let elapsed = now_ms() - start; + signals.track_with_properties("cache_fetch", json!({"response_ms": elapsed, "cache_hit": elapsed < 100.0, "fetch_number": fetch_count() + 1})); + data.set(Some(stats)); + response_ms.set(Some(elapsed)); + fetch_count.set(fetch_count() + 1); } loading.set(false); }); diff --git a/examples/with-dioxus/demo/frontend/src/components/users_section.rs b/examples/with-dioxus/demo/frontend/src/components/users_section.rs index d519f828..8afb7ba0 100644 --- a/examples/with-dioxus/demo/frontend/src/components/users_section.rs +++ b/examples/with-dioxus/demo/frontend/src/components/users_section.rs @@ -4,11 +4,30 @@ use serde_json::json; use crate::forge::{ CreateUserParams, DeleteUserParams, UpdateUserParams, User, use_create_user, use_delete_user, - use_get_users_subscription, use_update_user, + use_forge_auth, use_get_users_subscription, use_update_user, }; +/// `get_users` requires an authenticated session, so only mount the subscribing +/// inner component once logged in. Subscribing while anonymous would fire a 401 +/// on every page load (Dioxus hooks can't be called conditionally, so the gate +/// has to live at the component boundary). #[component] pub fn UsersSection(selected_user: Signal>) -> Element { + let auth = use_forge_auth(); + rsx! { + if auth.is_authenticated() { + UsersSectionInner { selected_user } + } else { + section { class: "card", + h2 { "Users " span { class: "badge green", "crud + subscribe" } } + p { class: "muted", "Log in to manage users." 
} + } + } + } +} + +#[component] +fn UsersSectionInner(selected_user: Signal>) -> Element { let create_user = use_create_user(); let update_user = use_update_user(); let delete_user = use_delete_user(); diff --git a/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs b/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs index 3d5b365e..b93126ed 100644 --- a/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs @@ -1,14 +1,15 @@ use dioxus::prelude::*; use forge_dioxus::use_signals; -use hmac::{Hmac, Mac}; use serde_json::json; -use sha2::Sha256; use super::{format_time, generate_key}; -use crate::forge::use_get_webhook_events_subscription; +use crate::forge::{ + TriggerDemoWebhookInput, trigger_demo_webhook, use_forge_client, + use_get_webhook_events_subscription, +}; #[component] -pub fn WebhookCard(api_url: String) -> Element { +pub fn WebhookCard() -> Element { let signals = use_signals(); let state = use_get_webhook_events_subscription(); let events = state.data.clone().unwrap_or_default(); @@ -17,6 +18,8 @@ pub fn WebhookCard(api_url: String) -> Element { let mut key_used = use_signal(|| false); let mut webhook_error = use_signal(|| None::); + let client = use_forge_client(); + rsx! { section { class: "card", h2 { "Webhook " span { class: "badge", "webhook" } } @@ -42,23 +45,27 @@ pub fn WebhookCard(api_url: String) -> Element { } button { disabled: key_used(), onclick: { - let api_url = api_url.clone(); let signals = signals.clone(); + let client = client.clone(); move |_| { if key_used() { return; } webhook_error.set(None); let key = idempotency_key(); - let api_url = api_url.clone(); let signals = signals.clone(); + let client = client.clone(); spawn(async move { - match trigger_webhook(&api_url, &key).await { - Ok(()) => { + // The HMAC secret lives on the server. 
The backend signs + // and POSTs the webhook to itself so the WASM bundle + // never ships the secret. + let input = TriggerDemoWebhookInput::new(key.clone()); + match trigger_demo_webhook(&client, input).await { + Ok(_) => { signals.track_with_properties("webhook_sent", json!({"idempotency_key": &key})); key_used.set(true); } - Err(msg) => { + Err(e) => { signals.track("webhook_error"); - webhook_error.set(Some(msg)); + webhook_error.set(Some(e.to_string())); } } }); @@ -87,53 +94,3 @@ pub fn WebhookCard(api_url: String) -> Element { } } } - -async fn trigger_webhook(api_url: &str, idempotency_key: &str) -> Result<(), String> { - #[cfg(target_arch = "wasm32")] - let now = js_sys::Date::now(); - #[cfg(not(target_arch = "wasm32"))] - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as f64; - let payload = serde_json::json!({ "action": "test", "ts": now }).to_string(); - - let mut mac = Hmac::::new_from_slice(b"demo-secret").map_err(|e| e.to_string())?; - mac.update(payload.as_bytes()); - let signature = hex::encode(mac.finalize().into_bytes()); - - // HMAC-SHA256 webhooks enforce a replay window: the server rejects any - // request whose `X-Webhook-Timestamp` (unix seconds) is missing or outside - // the 300s window. Send it alongside the signature. - let timestamp = (now / 1000.0) as i64; - - // In same-origin builds `api_url` is empty. Unlike browser `fetch`, reqwest - // can't parse a relative URL, so resolve it against the current origin. 
- #[cfg(target_arch = "wasm32")] - let base = if api_url.is_empty() { - web_sys::window() - .and_then(|w| w.location().origin().ok()) - .unwrap_or_default() - } else { - api_url.to_string() - }; - #[cfg(not(target_arch = "wasm32"))] - let base = api_url.to_string(); - - let resp = reqwest::Client::new() - .post(format!("{base}/_api/webhooks/demo")) - .header("Content-Type", "application/json") - .header("X-Webhook-Signature", signature) - .header("X-Webhook-Timestamp", timestamp.to_string()) - .header("X-Idempotency-Key", idempotency_key) - .body(payload) - .send() - .await - .map_err(|e| e.to_string())?; - - if resp.status().is_success() { - Ok(()) - } else { - Err(format!("Error: {}", resp.status().as_u16())) - } -} diff --git a/examples/with-dioxus/demo/frontend/src/forge/api.rs b/examples/with-dioxus/demo/frontend/src/forge/api.rs index 6729583d..5fa4545f 100644 --- a/examples/with-dioxus/demo/frontend/src/forge/api.rs +++ b/examples/with-dioxus/demo/frontend/src/forge/api.rs @@ -192,6 +192,16 @@ pub async fn register( pub fn use_register() -> Mutation { use_forge_mutation("register") } +pub async fn trigger_demo_webhook( + client: &ForgeClient, + args: TriggerDemoWebhookInput, +) -> Result { + client.call("trigger_demo_webhook", args).await +} + +pub fn use_trigger_demo_webhook() -> Mutation { + use_forge_mutation("trigger_demo_webhook") +} #[derive(Debug, Clone, PartialEq, serde::Serialize)] pub struct UpdateUserParams { pub id: String, diff --git a/examples/with-dioxus/demo/frontend/src/forge/types.rs b/examples/with-dioxus/demo/frontend/src/forge/types.rs index d0342e7a..de63438e 100644 --- a/examples/with-dioxus/demo/frontend/src/forge/types.rs +++ b/examples/with-dioxus/demo/frontend/src/forge/types.rs @@ -1,11 +1,6 @@ // @generated by FORGE - DO NOT EDIT -#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, 
Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -24,34 +19,34 @@ impl AuthResponse { Self { access_token: access_token.into(), refresh_token: refresh_token.into(), - user: user, + user, } } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct BinanceTrade { - pub symbol: String, - pub price: String, - pub quantity: String, - pub trade_time: i64, - pub is_buyer_maker: bool, + pub s: String, + pub p: String, + pub q: String, + pub T: i64, + pub m: bool, } impl BinanceTrade { pub fn new( - symbol: impl Into, - price: impl Into, - quantity: impl Into, - trade_time: i64, - is_buyer_maker: bool, + s: impl Into, + p: impl Into, + q: impl Into, + T: i64, + m: bool, ) -> Self { Self { - symbol: symbol.into(), - price: price.into(), - quantity: quantity.into(), - trade_time: trade_time, - is_buyer_maker: is_buyer_maker, + s: s.into(), + p: p.into(), + q: q.into(), + T, + m, } } } @@ -85,9 +80,9 @@ impl DemoStats { computed_at: impl Into, ) -> Self { Self { - total_users: total_users, - total_trades: total_trades, - total_webhooks: total_webhooks, + total_users, + total_trades, + total_webhooks, computed_at: computed_at.into(), } } @@ -108,15 +103,15 @@ impl ExportInput { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ExportOutput { - pub count: i64, + pub count: usize, pub data: String, pub format: String, } impl ExportOutput { - pub fn new(count: i64, data: impl Into, format: impl Into) -> Self { + pub fn new(count: usize, data: impl Into, format: impl Into) -> Self { Self { - count: count, + count, data: data.into(), format: format.into(), } @@ -133,8 +128,8 @@ pub struct IssApiResponse { impl IssApiResponse { pub fn new(iss_position: IssPosition, timestamp: i64, message: impl Into) -> Self { Self { - iss_position: iss_position, - timestamp: timestamp, + iss_position, + timestamp, message: message.into(), } } @@ -159,8 +154,8 @@ impl IssLocation { ) -> Self { Self { id: id.into(), - latitude: latitude, - 
longitude: longitude, + latitude, + longitude, api_timestamp: api_timestamp.into(), created_at: created_at.into(), } @@ -229,7 +224,7 @@ impl McpUserInfo { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, } } } @@ -292,15 +287,28 @@ impl Trade { Self { id: id.into(), symbol: symbol.into(), - price: price, - quantity: quantity, + price, + quantity, trade_time: trade_time.into(), - is_buyer_maker: is_buyer_maker, + is_buyer_maker, created_at: created_at.into(), } } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TriggerDemoWebhookInput { + pub idempotency_key: String, +} + +impl TriggerDemoWebhookInput { + pub fn new(idempotency_key: impl Into) -> Self { + Self { + idempotency_key: idempotency_key.into(), + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct User { pub id: String, @@ -309,7 +317,6 @@ pub struct User { pub role: UserRole, pub created_at: String, pub updated_at: String, - pub password_hash: Option, } impl User { @@ -325,17 +332,11 @@ impl User { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, created_at: created_at.into(), updated_at: updated_at.into(), - password_hash: None, } } - - pub fn password_hash(mut self, password_hash: impl Into) -> Self { - self.password_hash = Some(password_hash.into()); - self - } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -361,7 +362,7 @@ impl UserPublic { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, created_at: created_at.into(), updated_at: updated_at.into(), } @@ -392,7 +393,7 @@ pub struct VerificationOutput { impl VerificationOutput { pub fn new(verified: bool, token: impl Into) -> Self { Self { - verified: verified, + verified, token: token.into(), } } diff --git a/examples/with-dioxus/demo/frontend/src/pages/demo.rs b/examples/with-dioxus/demo/frontend/src/pages/demo.rs index 321d1f88..9c7bcd25 100644 --- 
a/examples/with-dioxus/demo/frontend/src/pages/demo.rs +++ b/examples/with-dioxus/demo/frontend/src/pages/demo.rs @@ -23,7 +23,7 @@ pub fn DemoPage() -> Element { div { class: "col", TradesCard {} AuthCard {} - WebhookCard { api_url: API_URL.to_string() } + WebhookCard {} VerificationCard { selected_user } } } diff --git a/examples/with-dioxus/demo/frontend/src/signals_bridge.rs b/examples/with-dioxus/demo/frontend/src/signals_bridge.rs index 954d6a45..02128c58 100644 --- a/examples/with-dioxus/demo/frontend/src/signals_bridge.rs +++ b/examples/with-dioxus/demo/frontend/src/signals_bridge.rs @@ -124,11 +124,12 @@ fn install_window_bridge(signals: forge_dioxus::ForgeSignals) { obj.insert("stack".to_string(), serde_json::Value::String(stack)); } } - let context = if ctx_val.is_object() && ctx_val.as_object().is_some_and(|o| !o.is_empty()) { - Some(ctx_val) - } else { - None - }; + let context = + if ctx_val.is_object() && ctx_val.as_object().is_some_and(|o| !o.is_empty()) { + Some(ctx_val) + } else { + None + }; signals.capture_error(&*message, context); }, ) diff --git a/examples/with-dioxus/demo/frontend/tests/fixtures.ts b/examples/with-dioxus/demo/frontend/tests/fixtures.ts index 8c312780..3435b927 100644 --- a/examples/with-dioxus/demo/frontend/tests/fixtures.ts +++ b/examples/with-dioxus/demo/frontend/tests/fixtures.ts @@ -6,7 +6,7 @@ export const API_URL = process.env.FORGE_TEST_URL || process.env.VITE_API_URL || "http://localhost:9081"; -export const ACTION_TIMEOUT = process.env.CI ? 
30_000 : 30_000; +export const ACTION_TIMEOUT = 30_000; export function uniqueId(prefix: string): string { return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; diff --git a/examples/with-dioxus/demo/frontend/tests/home.spec.ts b/examples/with-dioxus/demo/frontend/tests/home.spec.ts index e4799ca5..80cd2394 100644 --- a/examples/with-dioxus/demo/frontend/tests/home.spec.ts +++ b/examples/with-dioxus/demo/frontend/tests/home.spec.ts @@ -1,3 +1,4 @@ +import type { Page } from "@playwright/test"; import { test, expect, @@ -7,6 +8,28 @@ import { trackConsoleErrors, } from "./fixtures"; +// The release WASM bundle ships empty credential fields (prefill is debug-only), +// so fill them explicitly. Logging in rotates the token: the client tears down +// the anonymous SSE stream and re-subscribes every query over a fresh +// authenticated one. Wait for that re-subscription so reactive reads (and +// job/workflow push updates) reflect the authenticated session. +async function loginAsAdmin(page: Page) { + const auth = page.locator("section", { + has: page.getByText("refresh tokens"), + }); + await auth.getByPlaceholder("Email").fill("demo@example.com"); + await auth.getByPlaceholder(/Password/).fill("password123"); + const resubscribed = page.waitForResponse( + (res) => res.url().includes("/_api/subscribe") && res.status() === 200, + { timeout: ACTION_TIMEOUT * 3 }, + ); + await auth.locator('button[type="submit"]').click(); + await expect(auth.getByText("Logged in as")).toBeVisible({ + timeout: ACTION_TIMEOUT, + }); + await resubscribed; +} + async function signDemoWebhook(body: string): Promise { const encoder = new TextEncoder(); const keyData = await crypto.subtle.importKey( @@ -32,6 +55,7 @@ test("users CRUD stays reactive through create, edit, and delete", async ({ const updatedName = uniqueId("Edited"); await gotoReady(); + await loginAsAdmin(page); const section = page.locator("section", { has: page.getByRole("heading", { name: /users/i }), @@ 
-67,6 +91,7 @@ test("export job and verification workflow complete from the UI", async ({ gotoReady, }) => { await gotoReady(); + await loginAsAdmin(page); const exportSection = page.locator("section", { has: page.getByText("Export Job"), @@ -110,6 +135,8 @@ test("auth flow logs in, refreshes, and logs out cleanly", async ({ has: page.getByText("refresh tokens"), }); + await section.getByPlaceholder("Email").fill("demo@example.com"); + await section.getByPlaceholder(/Password/).fill("password123"); await section.locator('button[type="submit"]').click(); await expect(section.getByText("Logged in as")).toBeVisible({ timeout: ACTION_TIMEOUT, @@ -159,6 +186,7 @@ test("webhook endpoint rejects bad signatures and surfaces accepted events", asy expect(accepted.ok()).toBeTruthy(); await gotoReady(); + await loginAsAdmin(page); const section = page.locator("section", { has: page.getByText("Webhook"), }); diff --git a/examples/with-dioxus/demo/migrations/0001_initial.sql b/examples/with-dioxus/demo/migrations/0001_initial.sql index a854effb..66f90ec4 100644 --- a/examples/with-dioxus/demo/migrations/0001_initial.sql +++ b/examples/with-dioxus/demo/migrations/0001_initial.sql @@ -1,6 +1,10 @@ -CREATE TYPE user_role AS ENUM ('admin', 'member', 'guest'); +DO $$ BEGIN + CREATE TYPE user_role AS ENUM ('admin', 'member', 'guest'); +EXCEPTION + WHEN duplicate_object THEN NULL; +END $$; -CREATE TABLE users ( +CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY, email VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL, @@ -10,9 +14,9 @@ CREATE TABLE users ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX idx_users_email ON users(email); +CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email); -CREATE TABLE iss_location ( +CREATE TABLE IF NOT EXISTS iss_location ( id UUID PRIMARY KEY, latitude DOUBLE PRECISION NOT NULL, longitude DOUBLE PRECISION NOT NULL, @@ -20,7 +24,7 @@ CREATE TABLE iss_location ( created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); 
-CREATE TABLE trades ( +CREATE TABLE IF NOT EXISTS trades ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), symbol VARCHAR(20) NOT NULL, price DOUBLE PRECISION NOT NULL, @@ -30,9 +34,9 @@ CREATE TABLE trades ( created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE INDEX idx_trades_created_at ON trades(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_trades_created_at ON trades(created_at DESC); -CREATE TABLE webhook_events ( +CREATE TABLE IF NOT EXISTS webhook_events ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), idempotency_key VARCHAR(255) NOT NULL, webhook_name VARCHAR(100) NOT NULL, @@ -40,7 +44,7 @@ CREATE TABLE webhook_events ( processed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE INDEX idx_webhook_events_processed_at ON webhook_events(processed_at DESC); +CREATE INDEX IF NOT EXISTS idx_webhook_events_processed_at ON webhook_events(processed_at DESC); SELECT forge_enable_reactivity('users'); SELECT forge_enable_reactivity('iss_location'); @@ -48,7 +52,7 @@ SELECT forge_enable_reactivity('trades'); SELECT forge_enable_reactivity('webhook_events'); -- Stats snapshot table for cached query demo -CREATE TABLE demo_stats ( +CREATE TABLE IF NOT EXISTS demo_stats ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), total_users INTEGER NOT NULL DEFAULT 0, total_trades INTEGER NOT NULL DEFAULT 0, @@ -56,14 +60,18 @@ CREATE TABLE demo_stats ( computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); --- Sample user for demo (password: "password123") +-- Demo admin user (password: "password123"). Idempotent on re-run. +-- IMPORTANT: This is a known-credential account for the demo only. +-- For production deployments, delete this seed block before running migrations +-- and create your first admin via a separate one-off script with a strong password. 
INSERT INTO users (id, email, name, role, password_hash, created_at, updated_at) VALUES ( 'a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d', 'demo@example.com', 'Demo User', - 'member', + 'admin', '$argon2id$v=19$m=19456,t=2,p=1$AjozmE60AjazLA3S4LXuvw$v+Jo+M5NZ+Q1K4ro1pDS4Hx0/cnHJ3uvmJC7RiNJkUg', NOW(), NOW() -); +) +ON CONFLICT (id) DO UPDATE SET role = EXCLUDED.role; diff --git a/examples/with-dioxus/demo/src/functions/auth.rs b/examples/with-dioxus/demo/src/functions/auth.rs index acfcd4fd..4c0ab42c 100644 --- a/examples/with-dioxus/demo/src/functions/auth.rs +++ b/examples/with-dioxus/demo/src/functions/auth.rs @@ -20,7 +20,12 @@ pub struct RefreshInput { } async fn auth_response(ctx: &MutationContext, user: &User) -> Result { - let pair = ctx.issue_token_pair(user.id, &["user"]).await?; + let role = match user.role { + UserRole::Admin => "admin", + UserRole::Member => "user", + UserRole::Guest => "guest", + }; + let pair = ctx.issue_token_pair(user.id, &[role]).await?; Ok(AuthResponse { access_token: pair.access_token, refresh_token: pair.refresh_token, diff --git a/examples/with-dioxus/demo/src/functions/export.rs b/examples/with-dioxus/demo/src/functions/export.rs index 365fe3b5..9c5fd87b 100644 --- a/examples/with-dioxus/demo/src/functions/export.rs +++ b/examples/with-dioxus/demo/src/functions/export.rs @@ -14,7 +14,12 @@ pub struct ExportOutput { pub format: String, } -/// Export users as CSV or JSON with progress reporting +/// Export users as CSV or JSON with progress reporting. +/// +/// The `tokio::time::sleep` calls below are SIMULATED work — they exist solely so the +/// progress UI is visible in the demo. Replace them with real I/O (S3 puts, large +/// DB scans, format conversion) in production code. Never ship sleep-padded jobs: +/// they pin worker slots and inflate p99 for no value. 
#[forge::job( timeout = "5m", priority = "low", diff --git a/examples/with-dioxus/demo/src/functions/iss.rs b/examples/with-dioxus/demo/src/functions/iss.rs index 247d5bec..49875212 100644 --- a/examples/with-dioxus/demo/src/functions/iss.rs +++ b/examples/with-dioxus/demo/src/functions/iss.rs @@ -47,7 +47,7 @@ pub async fn iss_location(ctx: &CronContext) -> Result<()> { let response = ctx .http() - .get("http://api.open-notify.org/iss-now.json") + .get("https://api.open-notify.org/iss-now.json") .send() .await .map_err(|e| ForgeError::internal(format!("HTTP request failed: {}", e)))?; @@ -69,8 +69,16 @@ pub async fn iss_location(ctx: &CronContext) -> Result<()> { tracing::warn!(message = %data.message, "ISS API non-success"); } - let latitude: f64 = data.iss_position.latitude.parse().unwrap_or(0.0); - let longitude: f64 = data.iss_position.longitude.parse().unwrap_or(0.0); + let latitude: f64 = data + .iss_position + .latitude + .parse() + .map_err(|e| ForgeError::Deserialization(format!("invalid latitude: {e}")))?; + let longitude: f64 = data + .iss_position + .longitude + .parse() + .map_err(|e| ForgeError::Deserialization(format!("invalid longitude: {e}")))?; sqlx::query!( "INSERT INTO iss_location (id, latitude, longitude, api_timestamp, created_at) \ diff --git a/examples/with-dioxus/demo/src/functions/mcp.rs b/examples/with-dioxus/demo/src/functions/mcp.rs index 31d7c945..5a0d77e3 100644 --- a/examples/with-dioxus/demo/src/functions/mcp.rs +++ b/examples/with-dioxus/demo/src/functions/mcp.rs @@ -30,10 +30,10 @@ pub async fn mcp_me(ctx: &McpToolContext) -> forge::forge_core::Result forge::forge_core::Result> { + let _ = ctx.user_id()?; let mut conn = ctx.conn().await?; let users = sqlx::query_as!( @@ -63,13 +63,13 @@ pub struct McpGetUserInput { name = "demo.get_user_by_email", title = "Get User by Email", description = "Look up a single user by their email address", - public, read_only )] pub async fn mcp_get_user_by_email( ctx: &McpToolContext, input: 
McpGetUserInput, ) -> forge::forge_core::Result> { + let _ = ctx.user_id()?; let mut conn = ctx.conn().await?; let user = sqlx::query_as!( diff --git a/examples/with-dioxus/demo/src/functions/stats.rs b/examples/with-dioxus/demo/src/functions/stats.rs index 907ac11f..0eb5e8d3 100644 --- a/examples/with-dioxus/demo/src/functions/stats.rs +++ b/examples/with-dioxus/demo/src/functions/stats.rs @@ -3,6 +3,8 @@ use forge::prelude::*; #[forge::query(cache = "10s", auth = "none")] pub async fn get_demo_stats(ctx: &QueryContext) -> Result { + // Simulated work to make the `cache = "10s"` demo visible to a human watching the UI. + // Real handlers must not call sleep — it pins a worker thread for no useful work. tokio::time::sleep(std::time::Duration::from_millis(500)).await; let row = sqlx::query!( diff --git a/examples/with-dioxus/demo/src/functions/trades.rs b/examples/with-dioxus/demo/src/functions/trades.rs index 1b1792f7..bdf94df3 100644 --- a/examples/with-dioxus/demo/src/functions/trades.rs +++ b/examples/with-dioxus/demo/src/functions/trades.rs @@ -69,35 +69,53 @@ pub async fn trade_stream(ctx: &DaemonContext) -> Result<()> { msg = read.next() => { match msg { Some(Ok(Message::Text(text))) => { - if let Ok(trade) = serde_json::from_str::(&text) { - let price: f64 = trade.price.parse().unwrap_or(0.0); - let quantity: f64 = trade.quantity.parse().unwrap_or(0.0); - let trade_time = chrono::DateTime::from_timestamp_millis(trade.trade_time) + match serde_json::from_str::(&text) { + Ok(trade) => { + let price: f64 = trade.price.parse().map_err(|e| { + ForgeError::Deserialization(format!("invalid trade price: {e}")) + })?; + let quantity: f64 = trade.quantity.parse().map_err(|e| { + ForgeError::Deserialization(format!( + "invalid trade quantity: {e}" + )) + })?; + let trade_time = chrono::DateTime::from_timestamp_millis( + trade.trade_time, + ) .unwrap_or_else(Utc::now); - sqlx::query!( - "INSERT INTO trades (id, symbol, price, quantity, trade_time, is_buyer_maker, 
created_at) \ - VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, NOW())", - &trade.symbol, - price, - quantity, - trade_time, - trade.is_buyer_maker - ) - .execute(ctx.db()) - .await - .ok(); + sqlx::query!( + "INSERT INTO trades (id, symbol, price, quantity, trade_time, is_buyer_maker, created_at) \ + VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, NOW())", + &trade.symbol, + price, + quantity, + trade_time, + trade.is_buyer_maker + ) + .execute(ctx.db()) + .await?; + } + Err(e) => { + tracing::warn!("Skipping unparsable trade message: {e}"); + } } } Some(Ok(Message::Close(_))) => { - tracing::warn!("WebSocket closed by server"); - break; + return Err(ForgeError::internal( + "Binance WebSocket closed by server; daemon will restart", + )); } Some(Err(e)) => { - tracing::error!("WebSocket error: {}", e); - break; + return Err(ForgeError::internal(format!( + "Binance WebSocket error: {e}" + ))); + } + None => { + return Err(ForgeError::internal( + "Binance WebSocket stream ended; daemon will restart", + )); } - None => break, _ => {} } } diff --git a/examples/with-dioxus/demo/src/functions/users.rs b/examples/with-dioxus/demo/src/functions/users.rs index 9d7b90f2..33955888 100644 --- a/examples/with-dioxus/demo/src/functions/users.rs +++ b/examples/with-dioxus/demo/src/functions/users.rs @@ -1,9 +1,11 @@ use crate::schema::{User, UserRole}; use forge::prelude::*; -/// List all users with reactive subscription support -#[forge::query(cache = "30s", auth = "none")] +/// List all users with reactive subscription support. +/// Reading the global user list requires an authenticated session. +#[forge::query(cache = "30s", unscoped)] pub async fn get_users(ctx: &QueryContext) -> Result> { + let _ = ctx.user_id()?; sqlx::query_as!( User, r#" @@ -24,9 +26,10 @@ pub async fn get_users(ctx: &QueryContext) -> Result> { .map_err(Into::into) } -/// Get single user by ID -#[forge::query(timeout = "10s", auth = "none")] +/// Get single user by ID. Requires an authenticated session. 
+#[forge::query(timeout = "10s", unscoped)] pub async fn get_user(ctx: &QueryContext, id: Uuid) -> Result> { + let _ = ctx.user_id()?; sqlx::query_as!( User, r#" @@ -48,14 +51,15 @@ pub async fn get_user(ctx: &QueryContext, id: Uuid) -> Result> { .map_err(Into::into) } -/// Create a new user -#[forge::mutation(auth = "none")] +/// Create a new user. Requires the `admin` role. +#[forge::mutation(scope = "global")] pub async fn create_user( ctx: &MutationContext, email: String, name: String, role: Option, ) -> Result { + ctx.auth.require_role("admin")?; let id = Uuid::new_v4(); let now = Utc::now(); let role = role.unwrap_or_default(); @@ -87,8 +91,8 @@ pub async fn create_user( .map_err(Into::into) } -/// Update user with partial fields -#[forge::mutation(timeout = "30s", auth = "none")] +/// Update user with partial fields. Requires the `admin` role. +#[forge::mutation(timeout = "30s", scope = "global")] pub async fn update_user( ctx: &MutationContext, id: Uuid, @@ -96,6 +100,7 @@ pub async fn update_user( name: Option, role: Option, ) -> Result { + ctx.auth.require_role("admin")?; let mut conn = ctx.conn().await.map_err(ForgeError::Database)?; sqlx::query_as!( User, @@ -126,9 +131,10 @@ pub async fn update_user( .map_err(Into::into) } -/// Delete user by ID -#[forge::mutation(auth = "none")] +/// Delete user by ID. Requires the `admin` role. 
+#[forge::mutation(scope = "global")] pub async fn delete_user(ctx: &MutationContext, id: Uuid) -> Result { + ctx.auth.require_role("admin")?; let mut conn = ctx.conn().await.map_err(ForgeError::Database)?; let result = sqlx::query!("DELETE FROM users WHERE id = $1", id) .execute(&mut conn) @@ -136,3 +142,168 @@ pub async fn delete_user(ctx: &MutationContext, id: Uuid) -> Result { Ok(result.rows_affected() > 0) } + +#[cfg(all(test, feature = "testcontainers"))] +mod tests { + use super::*; + use forge::forge_core::function::{AuthContext, RequestMetadata}; + use forge::testing::{IsolatedTestDb, TestDatabase}; + use std::path::Path; + + async fn setup_db() -> IsolatedTestDb { + let base = TestDatabase::from_env().await.unwrap(); + let db = base.isolated("users_test").await.unwrap(); + db.run_sql(&forge::get_internal_sql()).await.unwrap(); + db.migrate(Path::new("migrations")).await.unwrap(); + db + } + + fn admin_auth() -> AuthContext { + AuthContext::authenticated(Uuid::new_v4(), vec!["admin".into()], Default::default()) + } + + fn query_ctx(pool: sqlx::PgPool) -> QueryContext { + QueryContext::new(pool, admin_auth(), RequestMetadata::default()) + } + + fn mutation_ctx(pool: sqlx::PgPool) -> MutationContext { + MutationContext::new(pool, admin_auth(), RequestMetadata::default()) + } + + #[tokio::test] + async fn test_create_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "test@example.com".into(), "Test User".into(), None) + .await + .unwrap(); + + assert_eq!(user.email, "test@example.com"); + assert_eq!(user.name, "Test User"); + assert_eq!(user.role, UserRole::default()); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_create_user_with_role() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user( + &ctx, + "admin@example.com".into(), + "Admin".into(), + Some(UserRole::Admin), + ) + .await + .unwrap(); + + assert_eq!(user.role, 
UserRole::Admin); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_create_user_requires_admin_role() { + let db = setup_db().await; + let ctx = MutationContext::new( + db.pool().clone(), + AuthContext::authenticated(Uuid::new_v4(), vec!["member".into()], Default::default()), + RequestMetadata::default(), + ); + + let result = create_user(&ctx, "nope@example.com".into(), "No Admin".into(), None).await; + + assert!( + matches!(result, Err(ForgeError::Forbidden(_))), + "non-admin must be rejected with Forbidden, got {result:?}" + ); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_users() { + let db = setup_db().await; + let m_ctx = mutation_ctx(db.pool().clone()); + + create_user(&m_ctx, "a@test.com".into(), "User A".into(), None) + .await + .unwrap(); + create_user(&m_ctx, "b@test.com".into(), "User B".into(), None) + .await + .unwrap(); + + let q_ctx = query_ctx(db.pool().clone()); + let users = get_users(&q_ctx).await.unwrap(); + assert!(users.len() >= 2); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_user_by_id() { + let db = setup_db().await; + let m_ctx = mutation_ctx(db.pool().clone()); + + let created = create_user(&m_ctx, "find@test.com".into(), "Find Me".into(), None) + .await + .unwrap(); + + let q_ctx = query_ctx(db.pool().clone()); + let found = get_user(&q_ctx, created.id).await.unwrap(); + assert!(found.is_some()); + assert_eq!(found.unwrap().id, created.id); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_user_not_found() { + let db = setup_db().await; + let ctx = query_ctx(db.pool().clone()); + + let result = get_user(&ctx, Uuid::new_v4()).await.unwrap(); + assert!(result.is_none()); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_update_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "update@test.com".into(), "Original".into(), None) + .await + .unwrap(); + + let 
updated = update_user( + &ctx, + user.id, + Some("new@test.com".into()), + Some("Updated".into()), + None, + ) + .await + .unwrap(); + + assert_eq!(updated.email, "new@test.com"); + assert_eq!(updated.name, "Updated"); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_delete_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "delete@test.com".into(), "Delete Me".into(), None) + .await + .unwrap(); + + let deleted = delete_user(&ctx, user.id).await.unwrap(); + assert!(deleted); + + let q_ctx = query_ctx(db.pool().clone()); + let found = get_user(&q_ctx, user.id).await.unwrap(); + assert!(found.is_none()); + db.cleanup().await.unwrap(); + } +} diff --git a/examples/with-dioxus/demo/src/functions/verification.rs b/examples/with-dioxus/demo/src/functions/verification.rs index f4dd645c..91f6e7f3 100644 --- a/examples/with-dioxus/demo/src/functions/verification.rs +++ b/examples/with-dioxus/demo/src/functions/verification.rs @@ -112,14 +112,15 @@ pub struct ConfirmVerificationInput { pub workflow_id: String, } -#[forge::mutation(auth = "none")] // forge_workflow_events is owned by the runtime, so the framework user's .sqlx // cache doesn't see it. Runtime sqlx::query is the right tool here. +#[forge::mutation(tables("forge_workflow_events"), scope = "global")] #[allow(clippy::disallowed_methods)] pub async fn confirm_verification( ctx: &MutationContext, input: ConfirmVerificationInput, ) -> Result { + let _ = ctx.user_id()?; // Insert the confirmation event into the workflow events table. // The scheduler's NOTIFY trigger will wake the waiting workflow. 
sqlx::query( diff --git a/examples/with-dioxus/demo/src/functions/webhook.rs b/examples/with-dioxus/demo/src/functions/webhook.rs index 8fe6cef8..70468ad9 100644 --- a/examples/with-dioxus/demo/src/functions/webhook.rs +++ b/examples/with-dioxus/demo/src/functions/webhook.rs @@ -1,4 +1,6 @@ use forge::prelude::*; +use hmac::{Hmac, Mac}; +use sha2::Sha256; /// Webhook event record stored in database #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, sqlx::FromRow)] @@ -59,3 +61,53 @@ pub async fn demo_webhook(ctx: &WebhookContext, payload: Value) -> Result Result { + let secret = ctx.env_require("WEBHOOK_SECRET")?; + let port: u16 = ctx.env_parse_or("PORT", 9081u16)?; + let payload = serde_json::json!({ + "action": "test", + "ts": Utc::now().timestamp_millis(), + }) + .to_string(); + + let mut mac = as Mac>::new_from_slice(secret.as_bytes()) + .map_err(|e| ForgeError::internal(format!("HMAC key init failed: {e}")))?; + mac.update(payload.as_bytes()); + let signature = hex::encode(mac.finalize().into_bytes()); + let timestamp = Utc::now().timestamp(); + + // Deliberate loopback call to this server's own webhook endpoint. The + // framework's `ctx.http()` client blocks private/loopback IPs (SSRF guard), + // so use a plain reqwest client for this intentional self-call. 
+ let response = reqwest::Client::new() + .post(format!("http://127.0.0.1:{port}/_api/webhooks/demo")) + .header("Content-Type", "application/json") + .header("X-Webhook-Signature", signature) + .header("X-Webhook-Timestamp", timestamp.to_string()) + .header("X-Idempotency-Key", &input.idempotency_key) + .body(payload) + .send() + .await + .map_err(|e| ForgeError::internal(format!("Webhook self-call failed: {e}")))?; + + if !response.status().is_success() { + return Err(ForgeError::internal(format!( + "Webhook returned status {}", + response.status().as_u16() + ))); + } + + Ok(true) +} diff --git a/examples/with-dioxus/minimal/Dockerfile b/examples/with-dioxus/minimal/Dockerfile index 61905ba6..ac170b65 100644 --- a/examples/with-dioxus/minimal/Dockerfile +++ b/examples/with-dioxus/minimal/Dockerfile @@ -1,9 +1,9 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app RUN cargo install cargo-watch --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -11,7 +11,7 @@ COPY frontend/Cargo.toml frontend/Dioxus.toml ./ COPY frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -24,7 +24,7 @@ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/forge-dioxus-minimal-template /app/forge-dioxus-minimal-template diff --git 
a/examples/with-dioxus/minimal/docker-compose.yml b/examples/with-dioxus/minimal/docker-compose.yml index 308334b6..46487659 100644 --- a/examples/with-dioxus/minimal/docker-compose.yml +++ b/examples/with-dioxus/minimal/docker-compose.yml @@ -5,7 +5,7 @@ services: dockerfile: Dockerfile target: dev ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/minimal/forge.toml b/examples/with-dioxus/minimal/forge.toml index db95d297..0edf5eb7 100644 --- a/examples/with-dioxus/minimal/forge.toml +++ b/examples/with-dioxus/minimal/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -48,9 +48,12 @@ otlp_endpoint = "http://localhost:4318" # --- RSA (RS256/RS384/RS512) - asymmetric, use for external providers --- # jwks_url = "" # Provider JWKS URLs: see auth reference docs -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path (DDoS-grade). strict: PG round-trip (billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. 
+# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-dioxus/minimal/frontend/playwright.config.ts b/examples/with-dioxus/minimal/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/minimal/frontend/playwright.config.ts +++ b/examples/with-dioxus/minimal/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/minimal/frontend/src/forge/types.rs b/examples/with-dioxus/minimal/frontend/src/forge/types.rs index 58a2ca15..d0e29b99 100644 --- a/examples/with-dioxus/minimal/frontend/src/forge/types.rs +++ b/examples/with-dioxus/minimal/frontend/src/forge/types.rs @@ -1,10 +1,5 @@ // @generated by FORGE - DO NOT EDIT -#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; diff --git a/examples/with-dioxus/minimal/migrations/0001_initial.sql.example b/examples/with-dioxus/minimal/migrations/0001_initial.sql.example index efa755eb..3d416208 100644 --- a/examples/with-dioxus/minimal/migrations/0001_initial.sql.example +++ b/examples/with-dioxus/minimal/migrations/0001_initial.sql.example @@ -1,6 +1,4 @@ --- @up - --- Replace with your tables here +-- Migrations are forward-only. Add your schema changes below. 
-- Example: -- CREATE TABLE IF NOT EXISTS users ( -- id UUID PRIMARY KEY, @@ -9,10 +7,3 @@ -- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -- ); -- SELECT forge_enable_reactivity('users'); - --- @down - --- Add your rollback statements here --- Example: --- SELECT forge_disable_reactivity('users'); --- DROP TABLE IF EXISTS users; diff --git a/examples/with-dioxus/realtime-todo-list/.env b/examples/with-dioxus/realtime-todo-list/.env index 05b41b5e..87b0a353 100644 --- a/examples/with-dioxus/realtime-todo-list/.env +++ b/examples/with-dioxus/realtime-todo-list/.env @@ -11,8 +11,8 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=todo-dioxus POSTGRES_PORT=5432 -# Optional: JWT secret for authentication -# FORGE_SECRET=your-secret-key-here +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-dioxus/realtime-todo-list/.env.example b/examples/with-dioxus/realtime-todo-list/.env.example index bd7f2e73..4e5bf825 100644 --- a/examples/with-dioxus/realtime-todo-list/.env.example +++ b/examples/with-dioxus/realtime-todo-list/.env.example @@ -13,5 +13,5 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=todo-dioxus POSTGRES_PORT=5432 -# Optional: JWT secret for authentication -# FORGE_SECRET=your-secret-key-here +# JWT signing secret. 
Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json new file mode 100644 index 00000000..fd027d65 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM todos WHERE user_id = $1 ORDER BY created_at DESC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "title", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "completed", + "type_info": "Bool" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26" +} diff --git a/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json similarity index 50% rename from .sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json index fb994d1e..5b63fbe4 100644 --- a/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json @@ -1,14 +1,15 @@ { "db_name": "PostgreSQL", - 
"query": "DELETE FROM todos WHERE id = $1", + "query": "DELETE FROM todos WHERE id = $1 AND user_id = $2", "describe": { "columns": [], "parameters": { "Left": [ + "Uuid", "Uuid" ] }, "nullable": [] }, - "hash": "183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19" + "hash": "2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json new file mode 100644 index 00000000..e2ac3226 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json @@ -0,0 +1,57 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO users (id, email, name, password_hash, created_at, updated_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n RETURNING id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Varchar", + "Varchar", + "Text", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c" +} diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json 
b/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json new file mode 100644 index 00000000..3e00e871 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE email = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3" +} diff --git a/examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json similarity index 68% rename from examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json index f5b84d9f..6ed720db 100644 --- a/examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json 
@@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO todos (title) VALUES ($1) RETURNING *", + "query": "INSERT INTO todos (user_id, title) VALUES ($1, $2) RETURNING *", "describe": { "columns": [ { @@ -10,22 +10,28 @@ }, { "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } ], "parameters": { "Left": [ + "Uuid", "Text" ] }, @@ -33,8 +39,9 @@ false, false, false, + false, false ] }, - "hash": "289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39" + "hash": "cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json similarity index 74% rename from examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json index 1ebed98d..37c21eb7 100644 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = COALESCE($2, completed)\n WHERE id = $3\n RETURNING *", + "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = COALESCE($2, completed)\n WHERE id = $3 AND user_id = $4\n RETURNING *", "describe": { "columns": [ { @@ -10,16 +10,21 @@ }, { 
"ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } @@ -28,6 +33,7 @@ "Left": [ "Text", "Bool", + "Uuid", "Uuid" ] }, @@ -35,8 +41,9 @@ false, false, false, + false, false ] }, - "hash": "c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f" + "hash": "d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json new file mode 100644 index 00000000..a22ee7e8 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22" +} diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json 
b/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json deleted file mode 100644 index 71889f63..00000000 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT * FROM todos ORDER BY created_at DESC", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "title", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "completed", - "type_info": "Bool" - }, - { - "ordinal": 3, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763" -} diff --git a/examples/with-dioxus/realtime-todo-list/Cargo.toml b/examples/with-dioxus/realtime-todo-list/Cargo.toml index 2d0a46ab..f77143b6 100644 --- a/examples/with-dioxus/realtime-todo-list/Cargo.toml +++ b/examples/with-dioxus/realtime-todo-list/Cargo.toml @@ -18,6 +18,8 @@ uuid = { version = "1", features = ["v4", "serde"] } chrono = { version = "0.4", features = ["serde"] } sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono", "uuid", "macros", "derive"] } dotenvy = "0.15" +argon2 = "0.5" +password-hash = "0.5" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-dioxus/realtime-todo-list/Dockerfile b/examples/with-dioxus/realtime-todo-list/Dockerfile index 509573a0..99f22236 100644 --- a/examples/with-dioxus/realtime-todo-list/Dockerfile +++ b/examples/with-dioxus/realtime-todo-list/Dockerfile @@ -1,11 +1,11 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app RUN cargo install cargo-watch --locked RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli 
--version 0.7.3 --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/examples/todo-dioxus/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -14,7 +14,7 @@ COPY examples/todo-dioxus/frontend/.forge ./.forge COPY examples/todo-dioxus/frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -26,7 +26,7 @@ COPY --from=frontend-builder /app/examples/todo-dioxus/frontend/dist ./examples/ WORKDIR /app/examples/todo-dioxus RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/todo-dioxus /app/todo-dioxus diff --git a/examples/with-dioxus/realtime-todo-list/docker-compose.yml b/examples/with-dioxus/realtime-todo-list/docker-compose.yml index b4402e0c..69158e0d 100644 --- a/examples/with-dioxus/realtime-todo-list/docker-compose.yml +++ b/examples/with-dioxus/realtime-todo-list/docker-compose.yml @@ -6,7 +6,7 @@ services: target: dev working_dir: /workspace/examples/with-dioxus/realtime-todo-list ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/realtime-todo-list/forge.toml b/examples/with-dioxus/realtime-todo-list/forge.toml index 75f6f81f..3e711262 100644 --- a/examples/with-dioxus/realtime-todo-list/forge.toml +++ b/examples/with-dioxus/realtime-todo-list/forge.toml @@ -15,7 +15,7 @@ url = 
"${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -39,7 +39,12 @@ otlp_endpoint = "http://localhost:4318" # job_timeout = "5m" # poll_interval = "1s" -# [auth] +[auth] +jwt_algorithm = "HS256" +jwt_secret = "${JWT_SECRET}" +jwt_audience = "${JWT_AUDIENCE}" + +# Legacy template values below kept for reference. # jwt_algorithm = "HS256" # HS256, HS384, HS512, RS256, RS384, RS512 # # --- HMAC (HS256/HS384/HS512) - symmetric, use for self-issued tokens --- @@ -48,9 +53,12 @@ otlp_endpoint = "http://localhost:4318" # --- RSA (RS256/RS384/RS512) - asymmetric, use for external providers --- # jwks_url = "" # Provider JWKS URLs: see auth reference docs -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path (DDoS-grade). strict: PG round-trip (billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. +# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts b/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts +++ b/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 
1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs index 99a20a4e..66adc0e2 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs @@ -29,6 +29,17 @@ pub fn use_list_todos() -> QueryState> { pub fn use_list_todos_subscription() -> SubscriptionState> { use_forge_subscription("list_todos", ()) } +pub async fn me(client: &ForgeClient) -> Result { + client.call("me", ()).await +} + +pub fn use_me() -> QueryState { + use_forge_query("me", ()) +} + +pub fn use_me_subscription() -> SubscriptionState { + use_forge_subscription("me", ()) +} pub async fn create_todo( client: &ForgeClient, args: CreateTodoInput, @@ -59,6 +70,36 @@ pub async fn delete_todo( pub fn use_delete_todo() -> Mutation { use_forge_mutation("delete_todo") } +pub async fn login( + client: &ForgeClient, + args: LoginInput, +) -> Result { + client.call("login", args).await +} + +pub fn use_login() -> Mutation { + use_forge_mutation("login") +} +pub async fn refresh_token( + client: &ForgeClient, + args: RefreshInput, +) -> Result { + client.call("refresh_token", args).await +} + +pub fn use_refresh_token() -> Mutation { + use_forge_mutation("refresh_token") +} +pub async fn register( + client: &ForgeClient, + args: RegisterInput, +) -> Result { + client.call("register", args).await +} + +pub fn use_register() -> Mutation { + use_forge_mutation("register") +} pub async fn update_todo( client: &ForgeClient, args: UpdateTodoInput, diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs index 42c8491f..d77824ca 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs @@ -1,13 +1,29 @@ // @generated 
by FORGE - DO NOT EDIT -#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct AuthResponse { + pub access_token: String, + pub refresh_token: String, + pub user: UserPublic, +} + +impl AuthResponse { + pub fn new( + access_token: impl Into, + refresh_token: impl Into, + user: UserPublic, + ) -> Self { + Self { + access_token: access_token.into(), + refresh_token: refresh_token.into(), + user, + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct CreateTodoInput { pub title: String, @@ -21,9 +37,59 @@ impl CreateTodoInput { } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LoginInput { + pub email: String, + pub password: String, +} + +impl LoginInput { + pub fn new(email: impl Into, password: impl Into) -> Self { + Self { + email: email.into(), + password: password.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RefreshInput { + pub refresh_token: String, +} + +impl RefreshInput { + pub fn new(refresh_token: impl Into) -> Self { + Self { + refresh_token: refresh_token.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RegisterInput { + pub email: String, + pub name: String, + pub password: String, +} + +impl RegisterInput { + pub fn new( + email: impl Into, + name: impl Into, + password: impl Into, + ) -> Self { + Self { + email: email.into(), + name: name.into(), + password: password.into(), + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Todo { pub id: String, + pub user_id: String, pub title: String, pub completed: bool, pub created_at: String, @@ -55,3 +121,47 @@ impl UpdateTodoInput { self } } + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub 
struct User { + pub id: String, + pub email: String, + pub name: String, + pub created_at: String, + pub updated_at: String, +} + +impl User { + pub fn new( + id: impl Into, + email: impl Into, + name: impl Into, + created_at: impl Into, + updated_at: impl Into, + ) -> Self { + Self { + id: id.into(), + email: email.into(), + name: name.into(), + created_at: created_at.into(), + updated_at: updated_at.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct UserPublic { + pub id: String, + pub email: String, + pub name: String, +} + +impl UserPublic { + pub fn new(id: impl Into, email: impl Into, name: impl Into) -> Self { + Self { + id: id.into(), + email: email.into(), + name: name.into(), + } + } +} diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs index 6d95b780..40e94038 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs @@ -3,7 +3,7 @@ mod todo_app; mod todo_item; use dioxus::prelude::*; -use forge::ForgeProvider; +use forge::ForgeAuthProvider; use todo_app::TodoApp; fn api_url() -> &'static str { @@ -19,8 +19,9 @@ fn App() -> Element { rsx! 
{ document::Title { "Todos" } document::Stylesheet { href: asset!("/public/style.css") } - ForgeProvider { + ForgeAuthProvider { url: api_url().to_string(), + app_name: "todo-dioxus".to_string(), TodoApp {} } } diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs index ec5980e0..dacb6a85 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs @@ -2,11 +2,164 @@ use dioxus::prelude::*; use forge_dioxus::use_signals; use serde_json::json; -use crate::forge::{CreateTodoInput, use_create_todo, use_list_todos_subscription}; +use crate::forge::{ + CreateTodoInput, LoginInput, RegisterInput, UserPublic, use_create_todo, use_forge_auth, + use_list_todos_subscription, use_login, use_register, +}; use crate::todo_item::TodoItem; #[component] pub fn TodoApp() -> Element { + let auth = use_forge_auth(); + + rsx! { + main { + div { class: "shell", + header { class: "hero", + h1 { "Todos" } + if auth.is_authenticated() { + UserBar {} + } + } + if auth.is_authenticated() { + TodoList {} + } else { + AuthPanel {} + } + } + } + } +} + +#[component] +fn UserBar() -> Element { + let mut auth = use_forge_auth(); + let viewer = auth.viewer::(); + let label = viewer + .as_ref() + .map(|u| u.name.clone()) + .unwrap_or_default(); + + rsx! 
{ + div { class: "user-row", + span { class: "user", "{label}" } + button { + class: "logout", + onclick: move |_| auth.logout(), + "Sign out" + } + } + } +} + +#[component] +fn AuthPanel() -> Element { + let mut auth = use_forge_auth(); + let signals = use_signals(); + let login_mut = use_login(); + let register_mut = use_register(); + + let mut mode = use_signal(|| "login".to_string()); + let mut email = use_signal(String::new); + let mut name = use_signal(String::new); + let mut password = use_signal(String::new); + let mut error = use_signal(|| None::); + let mut loading = use_signal(|| false); + + let handle_submit = { + let login_mut = login_mut.clone(); + let register_mut = register_mut.clone(); + let signals = signals.clone(); + move |evt: FormEvent| { + evt.prevent_default(); + let is_register = mode.read().as_str() == "register"; + let e = email.read().clone(); + let n = name.read().clone(); + let p = password.read().clone(); + let login_mut = login_mut.clone(); + let register_mut = register_mut.clone(); + let signals = signals.clone(); + spawn(async move { + loading.set(true); + error.set(None); + let res = if is_register { + register_mut.call(RegisterInput::new(&e, &n, &p)).await + } else { + login_mut.call(LoginInput::new(&e, &p)).await + }; + match res { + Ok(r) => { + signals.track_with_properties( + "auth_success", + json!({"mode": is_register}), + ); + auth.login_with_viewer( + r.access_token.clone(), + r.refresh_token.clone(), + &r.user, + ); + } + Err(e) => error.set(Some(e.message)), + } + loading.set(false); + }); + } + }; + + rsx! 
{ + section { class: "auth-panel", + div { class: "tabs", + button { + class: if mode.read().as_str() == "login" { "active" } else { "" }, + onclick: move |_| mode.set("login".into()), + "Sign in" + } + button { + class: if mode.read().as_str() == "register" { "active" } else { "" }, + onclick: move |_| mode.set("register".into()), + "Sign up" + } + } + form { onsubmit: handle_submit, + if mode.read().as_str() == "register" { + input { + r#type: "text", + placeholder: "Name", + value: "{name}", + oninput: move |e: FormEvent| name.set(e.value()), + required: true, + } + } + input { + r#type: "email", + placeholder: "Email", + value: "{email}", + oninput: move |e: FormEvent| email.set(e.value()), + required: true, + } + input { + r#type: "password", + placeholder: "Password (min 8 chars)", + value: "{password}", + oninput: move |e: FormEvent| password.set(e.value()), + minlength: "8", + required: true, + } + button { + r#type: "submit", + disabled: loading(), + if loading() { "..." } else if mode.read().as_str() == "login" { "Sign in" } else { "Sign up" } + } + } + if let Some(msg) = error() { + p { class: "error", "{msg}" } + } + } + } +} + +#[component] +fn TodoList() -> Element { let signals = use_signals(); let create_todo = use_create_todo(); let todo_state = use_list_todos_subscription(); @@ -17,121 +170,91 @@ pub fn TodoApp() -> Element { let todo_items = todo_state.data.clone().unwrap_or_default(); let remaining_count = todo_items.iter().filter(|t| !t.completed).count(); - rsx! 
{ - main { - div { - class: "shell", - header { - class: "hero", - h1 { "Todos" } - } - - section { - class: "input-panel", - div { - class: "input-row", - input { - r#type: "text", - placeholder: "What needs to be done?", - value: new_title(), - disabled: adding(), - oninput: move |event| new_title.set(event.value()), - onkeydown: { - let create_todo = create_todo.clone(); - let signals = signals.clone(); - move |event: KeyboardEvent| { - if event.key().to_string() == "Enter" { - let title = new_title().trim().to_string(); - if title.is_empty() || adding() { - return; - } - error.set(None); - adding.set(true); - let create_todo = create_todo.clone(); - let signals = signals.clone(); - spawn(async move { - match create_todo.call(CreateTodoInput::new(title.clone())).await { - Ok(_) => { - signals.track_with_properties("todo_created", json!({"title": &title})); - new_title.set(String::new()); - } - Err(err) => { - signals.track_with_properties("todo_create_error", json!({"error": &err.message})); - error.set(Some(err.message)); - } - } - adding.set(false); - }); - } - } - }, - } - button { - disabled: adding() || new_title().trim().is_empty(), - onclick: { - let create_todo = create_todo.clone(); - let signals = signals.clone(); - move |_| { - let title = new_title().trim().to_string(); - if title.is_empty() || adding() { - return; - } - error.set(None); - adding.set(true); - let create_todo = create_todo.clone(); - let signals = signals.clone(); - spawn(async move { - match create_todo.call(CreateTodoInput::new(title.clone())).await { - Ok(_) => { - signals.track_with_properties("todo_created", json!({"title": &title})); - new_title.set(String::new()); - } - Err(err) => { - signals.track_with_properties("todo_create_error", json!({"error": &err.message})); - error.set(Some(err.message)); - } - } - adding.set(false); - }); - } - }, - if adding() { "Adding..." 
} else { "Add" } - } + let submit = { + let create_todo = create_todo.clone(); + let signals = signals.clone(); + move || { + let title = new_title().trim().to_string(); + if title.is_empty() || adding() { + return; + } + error.set(None); + adding.set(true); + let create_todo = create_todo.clone(); + let signals = signals.clone(); + spawn(async move { + match create_todo.call(CreateTodoInput::new(title.clone())).await { + Ok(_) => { + signals.track_with_properties("todo_created", json!({"title": &title})); + new_title.set(String::new()); } - - if let Some(message) = error() { - p { class: "error", "{message}" } + Err(err) => { + signals.track_with_properties( + "todo_create_error", + json!({"error": &err.message}), + ); + error.set(Some(err.message)); } } + adding.set(false); + }); + } + }; - section { - class: "list-panel", - if !todo_items.is_empty() { - div { - class: "list-head", - span { class: "summary", "{remaining_count} remaining" } - } - } - - if todo_state.loading { - p { class: "status", "Loading..." } - } else if let Some(todo_error) = todo_state.error.as_ref() { - p { class: "error", "{todo_error.message}" } - } else if todo_items.is_empty() { - p { class: "status", "No todos yet. Add one above!" } - } else { - ul { - for todo in todo_items { - TodoItem { - key: "{todo.id}", - todo: todo, - error: error, - } + rsx! { + section { class: "input-panel", + div { class: "input-row", + input { + r#type: "text", + placeholder: "What needs to be done?", + value: new_title(), + disabled: adding(), + oninput: move |event| new_title.set(event.value()), + onkeydown: { + let mut submit = submit.clone(); + move |event: KeyboardEvent| { + if event.key().to_string() == "Enter" { + submit(); } } - p { class: "count", "{remaining_count} remaining" } + }, + } + button { + disabled: adding() || new_title().trim().is_empty(), + onclick: { + let mut submit = submit.clone(); + move |_| submit() + }, + if adding() { "Adding..." 
} else { "Add" } + } + } + if let Some(message) = error() { + p { class: "error", "{message}" } + } + } + section { class: "list-panel", + if !todo_items.is_empty() { + div { class: "list-head", + span { class: "summary", "{remaining_count} remaining" } + } + } + if todo_state.loading { + p { class: "status", "Loading..." } + } else if let Some(todo_error) = todo_state.error.as_ref() { + p { class: "error", "{todo_error.message}" } + } else if todo_items.is_empty() { + p { class: "status", "No todos yet. Add one above!" } + } else { + ul { + for todo in todo_items { + TodoItem { + key: "{todo.id}", + todo: todo, + error: error, + } } } + p { class: "count", "{remaining_count} remaining" } } } } diff --git a/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts b/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts index d2495029..3a663824 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts +++ b/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts @@ -5,42 +5,55 @@ import { uniqueId, trackConsoleErrors, } from "./fixtures"; +import type { Page } from "@playwright/test"; const INPUT = 'input[placeholder="What needs to be done?"]'; +const EMAIL = 'input[type="email"]'; +const PASSWORD = 'input[type="password"]'; +const NAME = 'input[placeholder="Name"]'; -async function deleteAllTodos( - rpc: (fn: string, args?: unknown) => Promise, +async function signUp( + page: Page, + email: string, + name: string, + password: string, ) { - const todos = await rpc("list_todos"); - if (!Array.isArray(todos)) return; - for (const todo of todos) { - await rpc("delete_todo", { id: todo.id }); - } + await page.getByRole("button", { name: "Sign up" }).first().click(); + await page.fill(NAME, name); + await page.fill(EMAIL, email); + await page.fill(PASSWORD, password); + await page.getByRole("button", { name: "Sign up" }).last().click(); + await expect(page.locator(INPUT)).toBeVisible({ timeout: ACTION_TIMEOUT }); 
} -test.beforeEach(async ({ rpc }) => { - await deleteAllTodos(rpc); -}); - -test.afterEach(async ({ rpc }) => { - await deleteAllTodos(rpc); -}); +// The app only subscribes to the todos query once authenticated, so reactivity +// readiness can't be detected until after sign-up. Arm the subscribe wait +// before submitting, then await it once the authed view renders. (×3 timeout +// for the WASM download → instantiate → init → SSE → subscribe path.) +async function signUpReady( + page: Page, + email: string, + name: string, + password: string, +) { + const subscribed = page.waitForResponse( + (res) => res.url().includes("/_api/subscribe") && res.status() === 200, + { timeout: ACTION_TIMEOUT * 3 }, + ); + await signUp(page, email, name, password); + await subscribed; +} -test("todo flow stays reactive through create, toggle, and delete", async ({ +test("authenticated user can create, toggle, and delete their todos", async ({ page, - gotoReady, }) => { - const title = uniqueId("release"); const errors = trackConsoleErrors(page); + const email = `${uniqueId("user")}@test.local`; - await gotoReady(); - await expect(page.locator("h1")).toHaveText("Todos"); - await expect( - page.locator(".status", { hasText: "No todos yet" }), - ).toBeVisible({ - timeout: ACTION_TIMEOUT, - }); + await page.goto("/"); + await signUpReady(page, email, "Solo", "password123"); + const title = uniqueId("release"); await page.fill(INPUT, title); await page.click(".input-row button"); @@ -49,22 +62,44 @@ test("todo flow stays reactive through create, toggle, and delete", async ({ await expect(page.locator(".count")).toHaveText("1 remaining", { timeout: ACTION_TIMEOUT, }); - await expect(page.locator(INPUT)).toHaveValue("", { - timeout: ACTION_TIMEOUT, - }); await todoItem.locator("button.toggle").click(); await expect(todoItem).toHaveClass(/completed/, { timeout: ACTION_TIMEOUT }); - await expect(page.locator(".count")).toHaveText("0 remaining", { - timeout: ACTION_TIMEOUT, - }); await 
todoItem.locator("button.delete").click(); await expect(todoItem).not.toBeVisible({ timeout: ACTION_TIMEOUT }); - await expect( - page.locator(".status", { hasText: "No todos yet" }), - ).toBeVisible({ + expect(errors).toHaveLength(0); +}); + +test("two users cannot see each other's todos", async ({ browser }) => { + const aliceEmail = `${uniqueId("alice")}@test.local`; + const bobEmail = `${uniqueId("bob")}@test.local`; + const aliceTitle = uniqueId("alice-task"); + const bobTitle = uniqueId("bob-task"); + + const aliceCtx = await browser.newContext(); + const alice = await aliceCtx.newPage(); + await alice.goto("/"); + await signUpReady(alice, aliceEmail, "Alice", "password123"); + await alice.fill(INPUT, aliceTitle); + await alice.click(".input-row button"); + await expect(alice.locator("li", { hasText: aliceTitle })).toBeVisible({ timeout: ACTION_TIMEOUT, }); - expect(errors).toHaveLength(0); + + const bobCtx = await browser.newContext(); + const bob = await bobCtx.newPage(); + await bob.goto("/"); + await signUpReady(bob, bobEmail, "Bob", "password123"); + await bob.fill(INPUT, bobTitle); + await bob.click(".input-row button"); + await expect(bob.locator("li", { hasText: bobTitle })).toBeVisible({ + timeout: ACTION_TIMEOUT, + }); + + await expect(bob.locator("li", { hasText: aliceTitle })).toHaveCount(0); + await expect(alice.locator("li", { hasText: bobTitle })).toHaveCount(0); + + await aliceCtx.close(); + await bobCtx.close(); }); diff --git a/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql b/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql index e717bf19..4f584e0f 100644 --- a/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql +++ b/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql @@ -1,8 +1,22 @@ -CREATE TABLE todos ( +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY, + email VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + password_hash TEXT NOT NULL, + created_at 
TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email); + +CREATE TABLE IF NOT EXISTS todos ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, title TEXT NOT NULL, completed BOOLEAN NOT NULL DEFAULT false, created_at TIMESTAMPTZ NOT NULL DEFAULT now() ); +CREATE INDEX IF NOT EXISTS idx_todos_user_id ON todos(user_id); + SELECT forge_enable_reactivity('todos'); diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs b/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs new file mode 100644 index 00000000..b23c4ff0 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs @@ -0,0 +1,140 @@ +use crate::schema::{AuthResponse, User, UserPublic}; +use forge::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterInput { + pub email: String, + pub name: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginInput { + pub email: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RefreshInput { + pub refresh_token: String, +} + +async fn auth_response(ctx: &MutationContext, user: &User) -> Result { + let pair = ctx.issue_token_pair(user.id, &["user"]).await?; + Ok(AuthResponse { + access_token: pair.access_token, + refresh_token: pair.refresh_token, + user: UserPublic::from(user.clone()), + }) +} + +fn validate_register(input: &RegisterInput) -> Result<(String, String)> { + let email = input.email.trim(); + if email.is_empty() { + return Err(ForgeError::Validation("Email is required".into())); + } + let name = input.name.trim(); + if name.is_empty() { + return Err(ForgeError::Validation("Name is required".into())); + } + if input.password.len() < 8 { + return Err(ForgeError::Validation( + "Password must be at least 8 characters".into(), + )); + } + 
Ok((email.to_string(), name.to_string())) +} + +#[forge::mutation(auth = "none")] +pub async fn register(ctx: &MutationContext, input: RegisterInput) -> Result { + let (email, name) = validate_register(&input)?; + + let password_hash = { + use argon2::PasswordHasher; + use password_hash::SaltString; + let salt = SaltString::generate(&mut password_hash::rand_core::OsRng); + argon2::Argon2::default() + .hash_password(input.password.as_bytes(), &salt) + .map_err(|e| ForgeError::internal(e.to_string()))? + .to_string() + }; + + let id = Uuid::new_v4(); + let now = Utc::now(); + let mut conn = ctx.conn().await?; + + let user = sqlx::query_as!( + User, + r#" + INSERT INTO users (id, email, name, password_hash, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, email, name, password_hash as "password_hash!", created_at, updated_at + "#, + id, + &email, + &name, + &password_hash, + now, + now + ) + .fetch_one(&mut conn) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.constraint() == Some("idx_users_email") => { + ForgeError::Validation("Email already registered".into()) + } + _ => ForgeError::from(e), + })?; + + auth_response(ctx, &user).await +} + +#[forge::mutation(auth = "none")] +pub async fn login(ctx: &MutationContext, input: LoginInput) -> Result { + let mut conn = ctx.conn().await?; + + let user = sqlx::query_as!( + User, + r#" + SELECT id, email, name, password_hash as "password_hash!", created_at, updated_at + FROM users WHERE email = $1 + "#, + &input.email + ) + .fetch_optional(&mut conn) + .await? 
+ .ok_or_else(|| ForgeError::Validation("Invalid email or password".into()))?; + + { + use argon2::PasswordVerifier; + let parsed = password_hash::PasswordHash::new(&user.password_hash) + .map_err(|e| ForgeError::internal(e.to_string()))?; + argon2::Argon2::default() + .verify_password(input.password.as_bytes(), &parsed) + .map_err(|_| ForgeError::Validation("Invalid email or password".into()))?; + } + + auth_response(ctx, &user).await +} + +#[forge::mutation(auth = "none")] +pub async fn refresh_token(ctx: &MutationContext, input: RefreshInput) -> Result { + ctx.rotate_refresh_token(&input.refresh_token).await +} + +#[forge::query(scope = "global")] +pub async fn me(ctx: &QueryContext) -> Result { + let user_id = ctx.user_id()?; + let user = sqlx::query_as!( + User, + r#" + SELECT id, email, name, password_hash as "password_hash!", created_at, updated_at + FROM users WHERE id = $1 + "#, + user_id + ) + .fetch_optional(ctx.db()) + .await? + .ok_or_else(|| ForgeError::NotFound("User not found".into()))?; + Ok(UserPublic::from(user)) +} diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs b/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs index 2fd7ed7a..5afdcc48 100644 --- a/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs +++ b/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs @@ -1 +1,2 @@ +mod auth; mod todos; diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs b/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs index 657492ec..fd49d3ab 100644 --- a/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs +++ b/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs @@ -15,26 +15,33 @@ pub struct UpdateTodoInput { pub completed: Option, } -#[forge::query(auth = "none", tables("todos"))] +#[forge::query(tables("todos"))] pub async fn list_todos(ctx: &QueryContext) -> Result> { - sqlx::query_as!(Todo, "SELECT * FROM todos ORDER BY created_at DESC") - 
.fetch_all(ctx.db()) - .await - .map_err(Into::into) + let user_id = ctx.user_id()?; + sqlx::query_as!( + Todo, + "SELECT * FROM todos WHERE user_id = $1 ORDER BY created_at DESC", + user_id + ) + .fetch_all(ctx.db()) + .await + .map_err(Into::into) } -#[forge::mutation(auth = "none")] +#[forge::mutation(scope = "global")] pub async fn create_todo(ctx: &MutationContext, input: CreateTodoInput) -> Result { if input.title.trim().is_empty() { return Err(ForgeError::Validation("Title cannot be empty".into())); } + let user_id = ctx.user_id()?; let title = input.title.trim().to_string(); let mut conn = ctx.conn().await?; sqlx::query_as!( Todo, - "INSERT INTO todos (title) VALUES ($1) RETURNING *", + "INSERT INTO todos (user_id, title) VALUES ($1, $2) RETURNING *", + user_id, title ) .fetch_one(&mut conn) @@ -42,8 +49,9 @@ pub async fn create_todo(ctx: &MutationContext, input: CreateTodoInput) -> Resul .map_err(Into::into) } -#[forge::mutation(auth = "none")] +#[forge::mutation] pub async fn update_todo(ctx: &MutationContext, input: UpdateTodoInput) -> Result { + let user_id = ctx.user_id()?; let title = input.title.as_deref(); let mut conn = ctx.conn().await?; @@ -52,24 +60,174 @@ pub async fn update_todo(ctx: &MutationContext, input: UpdateTodoInput) -> Resul "UPDATE todos SET title = COALESCE($1, title), completed = COALESCE($2, completed) - WHERE id = $3 + WHERE id = $3 AND user_id = $4 RETURNING *", title, input.completed, - input.id + input.id, + user_id ) .fetch_optional(&mut conn) .await? 
.ok_or_else(|| ForgeError::NotFound("Todo not found".into())) } -#[forge::mutation(auth = "none")] +#[forge::mutation] pub async fn delete_todo(ctx: &MutationContext, id: Uuid) -> Result { + let user_id = ctx.user_id()?; let mut conn = ctx.conn().await?; - let result = sqlx::query!("DELETE FROM todos WHERE id = $1", id) - .execute(&mut conn) - .await?; + let result = sqlx::query!( + "DELETE FROM todos WHERE id = $1 AND user_id = $2", + id, + user_id + ) + .execute(&mut conn) + .await?; Ok(result.rows_affected() > 0) } + +#[cfg(all(test, feature = "testcontainers"))] +mod tests { + use super::*; + use forge::forge_core::function::{AuthContext, RequestMetadata}; + use forge::testing::{IsolatedTestDb, TestDatabase}; + use std::path::Path; + + async fn setup_db() -> IsolatedTestDb { + let base = TestDatabase::from_env().await.expect("test db"); + let db = base.isolated("todos_test").await.expect("isolated db"); + db.run_sql(&forge::get_internal_sql()) + .await + .expect("internal sql"); + db.migrate(Path::new("migrations")) + .await + .expect("migrations"); + db + } + + async fn seed_user(pool: &sqlx::PgPool) -> Uuid { + let id = Uuid::new_v4(); + sqlx::query!( + "INSERT INTO users (id, email, name, password_hash) VALUES ($1, $2, $3, $4)", + id, + format!("{id}@test.local"), + "Test User", + "x" + ) + .execute(pool) + .await + .expect("seed user"); + id + } + + fn query_ctx(pool: sqlx::PgPool, user_id: Uuid) -> QueryContext { + QueryContext::new( + pool, + AuthContext::authenticated(user_id, vec!["user".into()], Default::default()), + RequestMetadata::default(), + ) + } + + fn mutation_ctx(pool: sqlx::PgPool, user_id: Uuid) -> MutationContext { + MutationContext::new( + pool, + AuthContext::authenticated(user_id, vec!["user".into()], Default::default()), + RequestMetadata::default(), + ) + } + + #[tokio::test] + async fn create_todo_trims_and_persists_title() { + let db = setup_db().await; + let uid = seed_user(db.pool()).await; + let ctx = 
mutation_ctx(db.pool().clone(), uid); + + let todo = create_todo( + &ctx, + CreateTodoInput { + title: " ship tests ".into(), + }, + ) + .await + .expect("create"); + + assert_eq!(todo.title, "ship tests"); + assert_eq!(todo.user_id, uid); + assert!(!todo.completed); + db.cleanup().await.expect("cleanup"); + } + + #[tokio::test] + async fn list_todos_isolates_by_user() { + let db = setup_db().await; + let alice = seed_user(db.pool()).await; + let bob = seed_user(db.pool()).await; + + let alice_mut = mutation_ctx(db.pool().clone(), alice); + let bob_mut = mutation_ctx(db.pool().clone(), bob); + create_todo( + &alice_mut, + CreateTodoInput { + title: "alice".into(), + }, + ) + .await + .expect("alice todo"); + create_todo( + &bob_mut, + CreateTodoInput { + title: "bob".into(), + }, + ) + .await + .expect("bob todo"); + + let alice_q = query_ctx(db.pool().clone(), alice); + let bob_q = query_ctx(db.pool().clone(), bob); + let alice_todos = list_todos(&alice_q).await.expect("alice list"); + let bob_todos = list_todos(&bob_q).await.expect("bob list"); + + assert_eq!(alice_todos.len(), 1); + assert_eq!(alice_todos[0].title, "alice"); + assert_eq!(bob_todos.len(), 1); + assert_eq!(bob_todos[0].title, "bob"); + db.cleanup().await.expect("cleanup"); + } + + #[tokio::test] + async fn update_todo_blocks_other_users() { + let db = setup_db().await; + let alice = seed_user(db.pool()).await; + let bob = seed_user(db.pool()).await; + + let alice_mut = mutation_ctx(db.pool().clone(), alice); + let todo = create_todo( + &alice_mut, + CreateTodoInput { + title: "hers".into(), + }, + ) + .await + .expect("create"); + + let bob_mut = mutation_ctx(db.pool().clone(), bob); + let err = update_todo( + &bob_mut, + UpdateTodoInput { + id: todo.id, + title: Some("stolen".into()), + completed: None, + }, + ) + .await + .expect_err("bob must not update alice's todo"); + assert!(matches!(err, ForgeError::NotFound(_))); + + let deleted = delete_todo(&bob_mut, todo.id).await.expect("delete 
call"); + assert!(!deleted, "bob must not delete alice's todo"); + + db.cleanup().await.expect("cleanup"); + } +} diff --git a/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs b/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs index fccb7763..45fa3645 100644 --- a/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs +++ b/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs @@ -6,7 +6,43 @@ use uuid::Uuid; #[forge::model] pub struct Todo { pub id: Uuid, + pub user_id: Uuid, pub title: String, pub completed: bool, pub created_at: DateTime, } + +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct User { + pub id: Uuid, + pub email: String, + pub name: String, + pub created_at: DateTime, + pub updated_at: DateTime, + #[serde(skip_serializing)] + pub password_hash: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserPublic { + pub id: Uuid, + pub email: String, + pub name: String, +} + +impl From for UserPublic { + fn from(u: User) -> Self { + Self { + id: u.id, + email: u.email, + name: u.name, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthResponse { + pub access_token: String, + pub refresh_token: String, + pub user: UserPublic, +} diff --git a/examples/with-svelte/demo/.env b/examples/with-svelte/demo/.env index 7344ecbe..60da3aae 100644 --- a/examples/with-svelte/demo/.env +++ b/examples/with-svelte/demo/.env @@ -1,21 +1,17 @@ -# Server +# Dev-only environment for `forge test` and local runs. NOT shipped to users: +# `scripts/build-template-archive.sh` excludes `.env`, and the webhook secret is +# used server-side only (never in the browser bundle). Users copy `.env.example` +# and generate their own secrets. Mirrors the realtime-todo-list convention. 
HOST=0.0.0.0 PORT=9081 - -# Logging (error, warn, info, debug, trace) RUST_LOG=info,forge_runtime::function::executor=trace - -# Postgres container settings POSTGRES_USER=postgres POSTGRES_PASSWORD=forge POSTGRES_DB=forge_svelte_demo_template POSTGRES_PORT=5432 - -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production - -# Webhook secret for HMAC-SHA256 signature validation +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate +JWT_AUDIENCE=forge-demo-dev WEBHOOK_SECRET=demo-secret - -# Enable offline mode for sqlx compile-time checks +SEED_DEMO_USER=true +CORS_ORIGIN=http://localhost:9080 SQLX_OFFLINE=true diff --git a/examples/with-svelte/demo/.env.example b/examples/with-svelte/demo/.env.example index 7344ecbe..c0e7972e 100644 --- a/examples/with-svelte/demo/.env.example +++ b/examples/with-svelte/demo/.env.example @@ -1,3 +1,5 @@ +# Copy to `.env` and fill in real values. Never commit `.env`. + # Server HOST=0.0.0.0 PORT=9081 @@ -11,11 +13,22 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=forge_svelte_demo_template POSTGRES_PORT=5432 -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 + +# JWT audience claim. Must match the audience configured in your auth provider. +JWT_AUDIENCE=CHANGE_ME_YOUR_AUDIENCE + +# HMAC secret used to verify inbound webhook signatures. +# Generate with: openssl rand -hex 32 +WEBHOOK_SECRET=CHANGE_ME_USE_OPENSSL_RAND_HEX_32 + +# Seed the demo user (demo@example.com / password123) at first migration. +# DEV ONLY. Leave unset (or `false`) in any deployed environment. +SEED_DEMO_USER=true -# Webhook secret for HMAC-SHA256 signature validation -WEBHOOK_SECRET=demo-secret +# CORS origin for the SvelteKit frontend. Override per environment. 
+CORS_ORIGIN=http://localhost:9080 # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-svelte/demo/.gitignore b/examples/with-svelte/demo/.gitignore index 6c4eb93f..c082ec29 100644 --- a/examples/with-svelte/demo/.gitignore +++ b/examples/with-svelte/demo/.gitignore @@ -13,6 +13,8 @@ frontend/playwright-report/ frontend/test-results/ # Environment +# `.env` is tracked with dev-only secrets so `forge test` works from a clean +# checkout; the template archive excludes it (see build-template-archive.sh). .env.local .env.*.local diff --git a/examples/with-svelte/demo/Cargo.toml b/examples/with-svelte/demo/Cargo.toml index 080cf5c3..f2df1231 100644 --- a/examples/with-svelte/demo/Cargo.toml +++ b/examples/with-svelte/demo/Cargo.toml @@ -6,7 +6,7 @@ rust-version = "1.92" publish = false [features] -default = ["embedded-frontend"] +default = [] embedded-frontend = ["dep:rust-embed", "forge/embedded-frontend"] testcontainers = ["forge/testcontainers"] @@ -25,6 +25,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-util = "0.3" argon2 = "0.5" password-hash = "0.5" +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-svelte/demo/Dockerfile b/examples/with-svelte/demo/Dockerfile index e9d2f26c..57890816 100644 --- a/examples/with-svelte/demo/Dockerfile +++ b/examples/with-svelte/demo/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1-slim-bookworm AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app @@ -15,19 +15,19 @@ RUN cargo install cargo-watch --locked && \ # Development command - no frontend embedding, watch for changes CMD ["cargo", "watch", "-x", "run --no-default-features"] -FROM oven/bun:1-alpine AS frontend-builder +FROM oven/bun:1.1.34-alpine AS frontend-builder WORKDIR /app/frontend COPY frontend/package.json frontend/bun.lock* ./ -RUN bun install --frozen-lockfile || bun install +RUN bun install 
--frozen-lockfile COPY frontend ./ RUN bun run build -FROM rust:1-alpine AS builder +FROM rust:1.92-alpine AS builder WORKDIR /app diff --git a/examples/with-svelte/demo/docker-compose.yml b/examples/with-svelte/demo/docker-compose.yml index 18ed2553..eb9c716b 100644 --- a/examples/with-svelte/demo/docker-compose.yml +++ b/examples/with-svelte/demo/docker-compose.yml @@ -5,7 +5,7 @@ services: dockerfile: Dockerfile target: dev ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -34,7 +34,7 @@ services: working_dir: /app command: sh -c "bun install && bun run dev --host 0.0.0.0 --port 9080" ports: - - "9080:9080" + - "127.0.0.1:9080:9080" env_file: - ./frontend/.env volumes: @@ -60,7 +60,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-svelte/demo/forge.toml b/examples/with-svelte/demo/forge.toml index dec1a2c7..ad7aa4a4 100644 --- a/examples/with-svelte/demo/forge.toml +++ b/examples/with-svelte/demo/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -42,7 +42,7 @@ otlp_endpoint = "${FORGE_OTEL_ENDPOINT-http://localhost:4318}" [auth] jwt_algorithm = "HS256" jwt_secret = "${JWT_SECRET}" -jwt_audience = "${JWT_AUDIENCE-https://api.forge-demo.local}" +jwt_audience = "${JWT_AUDIENCE}" [mcp] enabled = true @@ -52,9 +52,18 @@ session_ttl = "1h" allowed_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] require_protocol_version_header = true -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path, PG fallback for 
global keys (DDoS-grade). +# strict: every check round-trips to PG (cluster-wide correct, billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. +# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] + +[signals] +# Product analytics + diagnostics are off by default; this demo opts in to +# exercise the /_api/signal endpoint and the client SDK. +enabled = true # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-svelte/demo/frontend/playwright.config.ts b/examples/with-svelte/demo/frontend/playwright.config.ts index 678e4e7b..0c5b1573 100644 --- a/examples/with-svelte/demo/frontend/playwright.config.ts +++ b/examples/with-svelte/demo/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 90_000, workers: process.env.CI ? 
1 : undefined, reporter: "html", diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/api.ts b/examples/with-svelte/demo/frontend/src/lib/forge/api.ts index ea248e6e..462a83ff 100644 --- a/examples/with-svelte/demo/frontend/src/lib/forge/api.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/api.ts @@ -18,6 +18,7 @@ import type { RegisterInput, TokenPair, Trade, + TriggerDemoWebhookInput, User, UserRole, VerificationInput, @@ -70,6 +71,9 @@ export const refreshToken = (args: RefreshInput): Promise => getForgeClient().call("refresh_token", args); export const register = (args: RegisterInput): Promise => getForgeClient().call("register", args); +export const triggerDemoWebhook = ( + args: TriggerDemoWebhookInput, +): Promise => getForgeClient().call("trigger_demo_webhook", args); export const updateUser = (args: { id: string; email: string | null; diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts b/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts index 520357f8..6f651f14 100644 --- a/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts @@ -13,6 +13,7 @@ import { login, refreshToken, register, + triggerDemoWebhook, updateUser, } from "./api"; import { @@ -31,6 +32,7 @@ import type { RegisterInput, TokenPair, Trade, + TriggerDemoWebhookInput, User, UserRole, WebhookEvent, @@ -63,6 +65,10 @@ export const refreshToken$ = (): ReactiveMutation => toReactiveMutation(refreshToken); export const register$ = (): ReactiveMutation => toReactiveMutation(register); +export const triggerDemoWebhook$ = (): ReactiveMutation< + TriggerDemoWebhookInput, + boolean +> => toReactiveMutation(triggerDemoWebhook); export const updateUser$ = (): ReactiveMutation< { id: string; diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/types.ts b/examples/with-svelte/demo/frontend/src/lib/forge/types.ts index 7a0560c9..d93cdd3a 100644 --- 
a/examples/with-svelte/demo/frontend/src/lib/forge/types.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/types.ts @@ -7,11 +7,11 @@ export interface AuthResponse { } export interface BinanceTrade { - symbol: string; - price: string; - quantity: string; - trade_time: number; - is_buyer_maker: boolean; + s: string; + p: string; + q: string; + T: number; + m: boolean; } export interface ConfirmVerificationInput { @@ -90,6 +90,10 @@ export interface Trade { created_at: string; } +export interface TriggerDemoWebhookInput { + idempotency_key: string; +} + export interface User { id: string; email: string; @@ -97,7 +101,6 @@ export interface User { role: UserRole; created_at: string; updated_at: string; - password_hash?: string; } export interface UserPublic { diff --git a/examples/with-svelte/demo/frontend/src/routes/+page.svelte b/examples/with-svelte/demo/frontend/src/routes/+page.svelte index 7f6c5842..70a4e262 100644 --- a/examples/with-svelte/demo/frontend/src/routes/+page.svelte +++ b/examples/with-svelte/demo/frontend/src/routes/+page.svelte @@ -10,6 +10,7 @@ trackExportUsers, trackAccountVerification, confirmVerification, + triggerDemoWebhook, getUsers$, getIssLocation$, getTrades$, @@ -26,7 +27,8 @@ const signals = getForgeSignals(); const apiUrl = PUBLIC_API_URL; - const users = getUsers$(); + // `get_users` requires auth; only subscribe once logged in (avoids an + // anonymous 401 and a wasted SSE subscription). Created in the template. const issLocation = getIssLocation$(); const trades = getTrades$(); const webhookEvents = getWebhookEvents$(); @@ -55,8 +57,11 @@ // Auth form state (only form inputs and UI state are local) let authMode = $state<"login" | "register">("login"); - let authEmail = $state("demo@example.com"); - let authPassword = $state("password123"); + // Prefill credentials only when the SvelteKit build runs in dev mode (Vite `import.meta.env.DEV`). 
+ // Production bundles ship with empty fields so leaked demos don't double as one-click logins. + const DEV_PREFILL = import.meta.env.DEV; + let authEmail = $state(DEV_PREFILL ? "demo@example.com" : ""); + let authPassword = $state(DEV_PREFILL ? "password123" : ""); let authName = $state(""); let authLoading = $state(false); let authError = $state(null); @@ -173,39 +178,15 @@ async function triggerWebhook() { signals.breadcrumb("Sending webhook"); webhookError = null; - const secret = "demo-secret"; - const payload = JSON.stringify({ action: "test", ts: Date.now() }); - - const encoder = new TextEncoder(); - const key = await crypto.subtle.importKey( - "raw", - encoder.encode(secret), - { name: "HMAC", hash: "SHA-256" }, - false, - ["sign"] - ); - const signature = await crypto.subtle.sign("HMAC", key, encoder.encode(payload)); - const signatureHex = Array.from(new Uint8Array(signature)) - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - - const res = await fetch(`${apiUrl}/_api/webhooks/demo`, { - method: "POST", - headers: { - "Content-Type": "application/json", - "X-Webhook-Signature": signatureHex, - "X-Webhook-Timestamp": Math.floor(Date.now() / 1000).toString(), - "X-Idempotency-Key": idempotencyKey, - }, - body: payload, - }); - - if (res.ok) { + // The HMAC secret lives on the server. The backend signs and POSTs the + // webhook to itself so the browser bundle never ships the secret. + try { + await triggerDemoWebhook({ idempotency_key: idempotencyKey }); keyUsed = true; signals.track("webhook_sent", { idempotency_key: idempotencyKey }); - } else { - webhookError = `Error: ${res.status}`; - signals.track("webhook_error", { status: res.status }); + } catch (err: unknown) { + webhookError = err instanceof Error ? err.message : String(err); + signals.track("webhook_error"); } } @@ -562,7 +543,9 @@ -
+ {#if auth.isAuthenticated} + {@const users = getUsers$()} +

Users crud + subscribe

@@ -612,7 +595,8 @@ {:else if !users.loading}

No users yet. Create one above.

{/if} -
+
+ {/if}