diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index d8aaff7..6631d7a 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -7,6 +7,18 @@ on: - master tags: - "v*" + paths: + - "Dockerfile" + - "docker-entrypoint.sh" + - "config.docker.json" + - "go.mod" + - "go.sum" + - "cmd/**" + - "internal/**" + - "static/**" + - "frontend/**" + - ".dockerignore" + - ".github/workflows/docker-image.yml" workflow_dispatch: permissions: @@ -19,15 +31,39 @@ concurrency: env: REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository }} jobs: - build-and-push: + prep: runs-on: ubuntu-latest + outputs: + image_name_lc: ${{ steps.norm.outputs.image_name_lc }} + steps: + - name: Normalize image name + id: norm + run: | + echo "image_name_lc=${GITHUB_REPOSITORY,,}" >> "$GITHUB_OUTPUT" + + build: + runs-on: ubuntu-latest + needs: prep + strategy: + fail-fast: false + matrix: + include: + - platform: linux/amd64 + arch: amd64 + - platform: linux/arm64 + arch: arm64 + steps: - name: Checkout uses: actions/checkout@v4 + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + with: + platforms: arm64 + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -42,21 +78,84 @@ jobs: id: meta uses: docker/metadata-action@v5 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + images: ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }} tags: | type=ref,event=branch type=ref,event=tag type=sha,prefix=sha- type=raw,value=latest,enable={{is_default_branch}} - - name: Build and push image + - name: Build and push by digest + id: build uses: docker/build-push-action@v6 with: context: . 
file: ./Dockerfile - platforms: linux/amd64,linux/arm64 - push: true - tags: ${{ steps.meta.outputs.tags }} + platforms: ${{ matrix.platform }} labels: ${{ steps.meta.outputs.labels }} - cache-from: type=gha - cache-to: type=gha,mode=max + outputs: type=image,name=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }},push-by-digest=true,name-canonical=true,push=true + cache-from: | + type=gha,scope=notion2api-${{ matrix.arch }} + type=registry,ref=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:buildcache-${{ matrix.arch }} + cache-to: | + type=gha,mode=max,scope=notion2api-${{ matrix.arch }} + type=registry,ref=${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:buildcache-${{ matrix.arch }},mode=max,oci-mediatypes=true,image-manifest=true + + - name: Export digest + run: | + mkdir -p "${{ runner.temp }}/digests" + digest="${{ steps.build.outputs.digest }}" + touch "${{ runner.temp }}/digests/${digest#sha256:}" + + - name: Upload digest artifact + uses: actions/upload-artifact@v4 + with: + name: digests-${{ matrix.arch }} + path: ${{ runner.temp }}/digests/* + if-no-files-found: error + retention-days: 1 + + merge: + runs-on: ubuntu-latest + needs: + - prep + - build + + steps: + - name: Download digest artifacts + uses: actions/download-artifact@v4 + with: + path: ${{ runner.temp }}/digests + pattern: digests-* + merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GHCR + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=sha- + type=raw,value=latest,enable={{is_default_branch}} + + - name: Create and push manifest list + working-directory: ${{ runner.temp 
}}/digests + run: | + tags=$(jq -r '.tags | map("-t " + .) | join(" ")' <<< '${{ steps.meta.outputs.json }}') + sources=$(printf '${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}@sha256:%s ' *) + docker buildx imagetools create $tags $sources + + - name: Inspect image + run: docker buildx imagetools inspect ${{ env.REGISTRY }}/${{ needs.prep.outputs.image_name_lc }}:${{ steps.meta.outputs.version }} diff --git a/.gitignore b/.gitignore index 1abd7ba..30a6578 100644 --- a/.gitignore +++ b/.gitignore @@ -41,8 +41,4 @@ frontend/out/ *.spec.tsx __tests__/ WEBUI_DEVELOPMENT_GUIDE.md - -# Rust FFI build artifacts (v2 wreq-ffi) -wreq-ffi/target/ -wreq-ffi/include/wreq_ffi.h -wreq-ffi/Cargo.lock +.serena/ diff --git a/Dockerfile b/Dockerfile index a0321c7..9d32d66 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,89 +2,17 @@ FROM --platform=$BUILDPLATFORM node:22-bookworm AS frontend-builder WORKDIR /frontend COPY frontend/package.json frontend/package-lock.json ./ -RUN npm ci +RUN --mount=type=cache,target=/root/.npm,sharing=locked \ + npm ci COPY frontend ./ RUN npm run build -FROM --platform=$BUILDPLATFORM rust:1.86-bookworm AS rust-builder -ARG BUILDPLATFORM -ARG TARGETPLATFORM -ARG TARGETARCH -ARG TARGETOS=linux -WORKDIR /src - -RUN apt-get update -o Acquire::Retries=5 \ - && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ - cmake perl build-essential libclang-dev clang lld file \ - gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ - gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ - && rm -rf /var/lib/apt/lists/* - -RUN set -eux; \ - case "${TARGETARCH}" in \ - amd64) RUST_TARGET=x86_64-unknown-linux-gnu ;; \ - arm64) RUST_TARGET=aarch64-unknown-linux-gnu ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - rustup target add "${RUST_TARGET}"; \ - echo "${RUST_TARGET}" > /tmp/rust_target - -ENV CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_LINKER=x86_64-linux-gnu-gcc \ - CC_x86_64_unknown_linux_gnu=x86_64-linux-gnu-gcc \ - 
CXX_x86_64_unknown_linux_gnu=x86_64-linux-gnu-g++ \ - AR_x86_64_unknown_linux_gnu=x86_64-linux-gnu-ar \ - CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc \ - CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc \ - CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ \ - AR_aarch64_unknown_linux_gnu=aarch64-linux-gnu-ar - -ENV CARGO_TARGET_DIR=/cargo-target - -COPY wreq-ffi ./wreq-ffi - -RUN --mount=type=cache,target=/usr/local/cargo/registry \ - --mount=type=cache,target=/usr/local/cargo/git \ - --mount=type=cache,target=/cargo-target,id=cargo-target-${TARGETARCH},sharing=locked \ - set -eux; \ - RUST_TARGET=$(cat /tmp/rust_target); \ - case "${TARGETARCH}" in \ - amd64) CC=x86_64-linux-gnu-gcc; CXX=x86_64-linux-gnu-g++; AR=x86_64-linux-gnu-ar ;; \ - arm64) CC=aarch64-linux-gnu-gcc; CXX=aarch64-linux-gnu-g++; AR=aarch64-linux-gnu-ar ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - export CC CXX AR; \ - echo "rust-builder toolchain: TARGETARCH=${TARGETARCH} RUST_TARGET=${RUST_TARGET} CC=${CC} CXX=${CXX} AR=${AR}"; \ - echo "rust-builder diag: BUILDPLATFORM=${BUILDPLATFORM} TARGETPLATFORM=${TARGETPLATFORM} TARGETARCH=${TARGETARCH} RUST_TARGET=${RUST_TARGET} host=$(uname -m)"; \ - cd wreq-ffi; \ - mkdir -p include; \ - touch src/lib.rs; \ - cargo build --release --target "${RUST_TARGET}"; \ - test -f include/wreq_ffi.h; \ - mkdir -p /out; \ - cp "${CARGO_TARGET_DIR}/${RUST_TARGET}/release/libwreq_ffi.a" /out/; \ - cp include/wreq_ffi.h /out/; \ - FIRST_MEMBER=$(ar t /out/libwreq_ffi.a | head -1); \ - AFILE=$(ar p /out/libwreq_ffi.a "$FIRST_MEMBER" | file -); \ - echo "rust-builder: first member ($FIRST_MEMBER) of /out/libwreq_ffi.a => ${AFILE}"; \ - case "${TARGETARCH}" in \ - amd64) echo "${AFILE}" | grep -q 'x86-64' || { echo "FATAL: /out/libwreq_ffi.a is not x86-64 (TARGETARCH=amd64). 
This usually means a cache mount got mixed up; try: docker buildx prune -af" >&2; exit 1; } ;; \ - arm64) echo "${AFILE}" | grep -q 'aarch64' || { echo "FATAL: /out/libwreq_ffi.a is not aarch64 (TARGETARCH=arm64). This usually means a cache mount got mixed up; try: docker buildx prune -af" >&2; exit 1; } ;; \ - esac; \ - echo "rust-builder: arch verified for TARGETARCH=${TARGETARCH}" - -FROM --platform=$BUILDPLATFORM golang:1.22-bookworm AS builder +FROM --platform=$BUILDPLATFORM golang:1.25.0-bookworm AS builder ARG BUILDPLATFORM ARG TARGETPLATFORM ARG TARGETOS ARG TARGETARCH -RUN apt-get update -o Acquire::Retries=5 \ - && apt-get install -y -o Acquire::Retries=5 --no-install-recommends \ - file \ - gcc-x86-64-linux-gnu g++-x86-64-linux-gnu \ - gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ - && rm -rf /var/lib/apt/lists/* - WORKDIR /src COPY go.mod go.sum ./ RUN --mount=type=cache,target=/go/pkg/mod \ @@ -95,48 +23,23 @@ COPY cmd ./cmd COPY internal ./internal COPY static ./static COPY --from=frontend-builder /frontend/out /src/static/admin -COPY --from=rust-builder /out/libwreq_ffi.a /src/wreq-ffi/target/release/libwreq_ffi.a -COPY --from=rust-builder /out/wreq_ffi.h /src/wreq-ffi/include/wreq_ffi.h RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ set -eux; \ - case "${TARGETARCH}" in \ - amd64) CC=x86_64-linux-gnu-gcc; CXX=x86_64-linux-gnu-g++ ;; \ - arm64) CC=aarch64-linux-gnu-gcc; CXX=aarch64-linux-gnu-g++ ;; \ - *) echo "unsupported TARGETARCH=${TARGETARCH}" >&2; exit 1 ;; \ - esac; \ - echo "go-builder diag: BUILDPLATFORM=${BUILDPLATFORM} TARGETPLATFORM=${TARGETPLATFORM} TARGETARCH=${TARGETARCH} CC=${CC} host=$(uname -m)"; \ - FIRST_MEMBER=$(ar t /src/wreq-ffi/target/release/libwreq_ffi.a | head -1); \ - AFILE=$(ar p /src/wreq-ffi/target/release/libwreq_ffi.a "$FIRST_MEMBER" | file -); \ - echo "go-builder: first member ($FIRST_MEMBER) of libwreq_ffi.a => ${AFILE}"; \ - case "${TARGETARCH}" in \ - amd64) echo 
"${AFILE}" | grep -q 'x86-64' || { echo "FATAL: libwreq_ffi.a in builder stage is not x86-64; rust-builder produced wrong arch or COPY layer is stale. Run: docker buildx prune -af" >&2; exit 1; } ;; \ - arm64) echo "${AFILE}" | grep -q 'aarch64' || { echo "FATAL: libwreq_ffi.a in builder stage is not aarch64; rust-builder produced wrong arch or COPY layer is stale. Run: docker buildx prune -af" >&2; exit 1; } ;; \ - esac; \ - test -f ./cmd/notion2api/main.go; \ - CGO_ENABLED=1 GOOS=${TARGETOS} GOARCH=${TARGETARCH} CC=${CC} CXX=${CXX} \ - go build -v -trimpath -tags wreq_ffi \ + CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ + go build -v -trimpath \ -ldflags="-s -w" \ -o /out/notion2api ./cmd/notion2api -FROM node:22-bookworm-slim +FROM alpine:3.22 +ARG TARGETARCH ENV TZ=Asia/Shanghai -ENV NODE_PATH=/opt/notion2api-helper/node_modules -ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin WORKDIR /app -RUN apt-get update \ - && apt-get install -y --no-install-recommends ca-certificates tzdata curl tini \ - && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*.deb \ - && mkdir -p /opt/notion2api-helper /app/config /app/data/notion_accounts /app/static - -RUN cd /opt/notion2api-helper \ - && npm init -y >/dev/null 2>&1 \ - && npm install --omit=dev --no-package-lock node-wreq@2.2.1 \ - && test -d "$NODE_PATH/node-wreq" \ - && npm cache clean --force >/dev/null 2>&1 +RUN apk add --no-cache ca-certificates tzdata curl tini \ + && mkdir -p /app/config /app/data/notion_accounts /app/static COPY --from=builder /out/notion2api /app/notion2api COPY --from=builder /src/static /app/static @@ -150,5 +53,5 @@ EXPOSE 8787 HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 CMD curl -fsS http://127.0.0.1:8787/healthz || exit 1 -ENTRYPOINT ["tini", "--", "docker-entrypoint.sh"] +ENTRYPOINT ["/sbin/tini", "--", "docker-entrypoint.sh"] CMD ["./notion2api", "--config", "/app/config/config.json"] diff --git a/README.md b/README.md index 
a789cf4..07df53f 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ docker compose up -d --build docker compose -f docker-compose.prod.yml up -d --build ``` +本地从源码开发需 Go `1.25.0+`(`go.mod` 已声明)。 + ## 默认入口 - API:`http://127.0.0.1:8787/v1/*` diff --git a/cmd/notion2api/main.go b/cmd/notion2api/main.go index 0ebdd15..15912dd 100644 --- a/cmd/notion2api/main.go +++ b/cmd/notion2api/main.go @@ -1,13 +1,9 @@ package main import ( - "log" - "notion2api/internal/app" - "notion2api/internal/wreq" ) func main() { - log.Printf("notion2api: wreq backend = %s", wreq.Version()) app.Main() } diff --git a/config.docker.json b/config.docker.json index 9bf604d..c18905f 100644 --- a/config.docker.json +++ b/config.docker.json @@ -29,6 +29,9 @@ "persist_continuation_sessions": true, "persist_sillytavern_bindings": true }, + "limits": { + "max_request_body_bytes": 4194304 + }, "admin": { "enabled": true, "password": "change-me-admin-password", @@ -46,6 +49,16 @@ "retry_on_auth_error": true, "auto_switch_account": true }, + "dispatch": { + "probe_cache_ttl_seconds": 45 + }, + "browser": { + "helper_pool_size": 0 + }, + "debug": { + "pprof_enabled": false, + "pprof_addr": "127.0.0.1:6060" + }, "features": { "use_web_search": true, "use_read_only_mode": false, diff --git a/config.example.json b/config.example.json index aebb2b4..d7651ea 100644 --- a/config.example.json +++ b/config.example.json @@ -43,6 +43,9 @@ "persist_continuation_sessions": true, "persist_sillytavern_bindings": true }, + "limits": { + "max_request_body_bytes": 4194304 + }, "features": { "use_web_search": true, "use_read_only_mode": false, @@ -69,6 +72,16 @@ "retry_on_auth_error": true, "auto_switch_account": true }, + "dispatch": { + "probe_cache_ttl_seconds": 45 + }, + "browser": { + "helper_pool_size": 0 + }, + "debug": { + "pprof_enabled": false, + "pprof_addr": "127.0.0.1:6060" + }, "accounts": [ { "email": "alice@example.com", diff --git a/go.mod b/go.mod index fca8103..8c7a871 100644 --- a/go.mod +++ 
b/go.mod @@ -1,17 +1,35 @@ module notion2api -go 1.22.0 +go 1.25.0 require modernc.org/sqlite v1.33.1 +require ( + github.com/andybalholm/brotli v1.2.1 // indirect + github.com/enetx/g v1.0.224 // indirect + github.com/enetx/http v1.0.28 // indirect + github.com/enetx/http2 v1.0.26 // indirect + github.com/enetx/http3 v1.0.7 // indirect + github.com/enetx/iter v0.0.0-20250912135656-f1583323588f // indirect + github.com/klauspost/compress v1.18.5 // indirect + github.com/quic-go/qpack v0.6.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af // indirect + github.com/wzshiming/socks5 v0.7.0 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/text v0.35.0 // indirect +) + require ( github.com/dustin/go-humanize v1.0.1 // indirect + github.com/enetx/surf v1.0.199 github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/sys v0.22.0 // indirect + golang.org/x/sys v0.35.0 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index 617fda4..c9ac762 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,68 @@ +github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= +github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod 
h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/enetx/g v1.0.224 h1:H/uonguFE4qG8YCn5bSpZX5Wh+wTSb+jgf3I2ZM25XM= +github.com/enetx/g v1.0.224/go.mod h1:lxhby3LjP8jOTGbxJ/PCd+2Zq1gYiSBbtL/llPhAg5c= +github.com/enetx/http v1.0.28 h1:IaNSSDFlAVVdHnYhNIR9wAN7GY4TWL/kkvYC3jOaueY= +github.com/enetx/http v1.0.28/go.mod h1:1f4mytfF/SfjATEJnynpwGS6aa1ALjb8DtmYgFVblY0= +github.com/enetx/http2 v1.0.26 h1:wy3lYGVwnIUY4Q+gyPPQCJ1a+BMXD1B7Unpyc/Csrxc= +github.com/enetx/http2 v1.0.26/go.mod h1:t54ex5HIS8V1+2j6cvEOv6umlrHsbUPFKQ54nYB58Nk= +github.com/enetx/http3 v1.0.7 h1:daFhveKBtv8rRallCjaHErzzSHIrq07ovoSvVkvhcMM= +github.com/enetx/http3 v1.0.7/go.mod h1:sqpVGZ9F1/wCiW6sjBUS2errKAh3SUYn6VlWE7LL6KM= +github.com/enetx/iter v0.0.0-20250912135656-f1583323588f h1:GUW+4AWfECIEJ9oAxgEAVGCpaozMCjRiUYnuR6Q0bCQ= +github.com/enetx/iter v0.0.0-20250912135656-f1583323588f/go.mod h1:oMZN8hGLUpi7QBlMEUqailocNy0NFAO/7Lu+Nwh9HMM= +github.com/enetx/surf v1.0.199 h1:RtqcwlyLM8O4U+43laNnNJwx5hALkH5cJRxDX1F2VjM= +github.com/enetx/surf v1.0.199/go.mod h1:c6g53gi273RBiZFO4THWIqpn5n9RLC6vw5WpUwHrT4U= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod 
h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= +github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= +github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af h1:er2acxbi3N1nvEq6HXHUAR1nTWEJmQfqiGR8EVT9rfs= +github.com/refraction-networking/utls v1.8.3-0.20260301010127-aa6edf4b11af/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/wzshiming/socks5 v0.7.0 h1:euJ+U48WrvVngi+opC8vAnpZ5sK12y1C2hPvb1f48Rg= +github.com/wzshiming/socks5 v0.7.0/go.mod h1:BvCAqlzocQN5xwLjBZDBbvWlrx8sCYSSbHEOf2wZgT0= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +go.uber.org/mock v0.5.2 
h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= modernc.org/ccgo/v4 v4.19.2 
h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= diff --git a/internal/app/account_discovery.go b/internal/app/account_discovery.go index b6a8636..e33dada 100644 --- a/internal/app/account_discovery.go +++ b/internal/app/account_discovery.go @@ -165,7 +165,7 @@ func discoverImportedAccountMetadata(ctx context.Context, cfg AppConfig, account } upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, accountEmail) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, accountEmail, cfg) if err != nil { return meta, err } diff --git a/internal/app/account_pool.go b/internal/app/account_pool.go index a287faf..10fcdb2 100644 --- a/internal/app/account_pool.go +++ b/internal/app/account_pool.go @@ -156,8 +156,10 @@ func sortDispatchCandidates(cfg AppConfig, accounts []NotionAccount, now time.Ti sort.Slice(accounts, func(i, j int) bool { left := accounts[i] right := accounts[j] - leftActive := canonicalEmailKey(left.Email) == activeKey - rightActive := canonicalEmailKey(right.Email) == activeKey + leftKey := getAccountEmailKey(left) + rightKey := getAccountEmailKey(right) + leftActive := leftKey == activeKey + rightActive := rightKey == activeKey if leftActive != rightActive { return leftActive } @@ -183,11 +185,11 @@ func sortDispatchCandidates(cfg AppConfig, accounts []NotionAccount, now time.Ti if !leftUsed.Equal(rightUsed) { return leftUsed.Before(rightUsed) } - return canonicalEmailKey(left.Email) < canonicalEmailKey(right.Email) + return leftKey < rightKey }) } -func pickDispatchCandidates(cfg AppConfig, now time.Time) []NotionAccount { +func buildDispatchCandidateOrder(cfg AppConfig, now time.Time) []NotionAccount { candidates := make([]NotionAccount, 0, len(cfg.Accounts)) for _, account := range cfg.Accounts { account = ensureAccountPaths(cfg, account) @@ -199,6 +201,16 @@ func pickDispatchCandidates(cfg AppConfig, now time.Time) []NotionAccount { 
return candidates } +func pickDispatchCandidatesFromSnapshot(bundle *snapshotBundle, now time.Time) []NotionAccount { + if bundle == nil { + return nil + } + if len(bundle.DispatchOrder) > 0 { + return bundle.DispatchOrder + } + return buildDispatchCandidateOrder(bundle.Config, now) +} + func applyAccountUpdate(cfg AppConfig, account NotionAccount, makeActive bool) AppConfig { account = ensureAccountPaths(cfg, account) cfg.UpsertAccount(account) @@ -241,8 +253,10 @@ func (a *App) runPromptWithSession(ctx context.Context, cfg AppConfig, session S if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, onDelta) } + transportClientNewTotalMetric.Add("standard", 1) client := newNotionAIClient(session, cfg, accountEmail) if onDelta != nil { + transportClientNewTotalMetric.Add("streaming", 1) client = newNotionAIStreamingClient(session, cfg, accountEmail) } execute := func(ctx context.Context, current PromptRunRequest, forward func(string) error) (InferenceResult, error) { @@ -261,8 +275,10 @@ func (a *App) runPromptWithSessionWithSink(ctx context.Context, cfg AppConfig, s if a.runPromptWithSessionOverride != nil { return a.runPromptWithSessionOverride(ctx, cfg, session, request, sink.Text) } + transportClientNewTotalMetric.Add("streaming", 1) client := newNotionAIStreamingClient(session, cfg, accountEmail) if sink.Text == nil && sink.Reasoning == nil && sink.ReasoningWarmup == nil && sink.KeepAlive == nil { + transportClientNewTotalMetric.Add("standard", 1) client = newNotionAIClient(session, cfg, accountEmail) } if sink.Reasoning != nil || sink.ReasoningWarmup != nil || sink.KeepAlive != nil { diff --git a/internal/app/accounts.go b/internal/app/accounts.go index a1c7331..021a2a6 100644 --- a/internal/app/accounts.go +++ b/internal/app/accounts.go @@ -10,6 +10,11 @@ import ( "strings" ) +var ( + accountPathSlugPattern = regexp.MustCompile(`[^a-z0-9]+`) + windowsAbsolutePathPattern = 
regexp.MustCompile(`^[A-Za-z]:[\\/].*`) +) + type ResolvedLoginHelper struct { SessionsDir string `json:"sessions_dir"` TimeoutSec int `json:"timeout_sec"` @@ -50,13 +55,19 @@ func canonicalEmailKey(email string) string { return strings.ToLower(strings.TrimSpace(email)) } +func getAccountEmailKey(account NotionAccount) string { + if account.emailKey != "" { + return account.emailKey + } + return canonicalEmailKey(account.Email) +} + func accountPathSlug(email string) string { clean := canonicalEmailKey(email) if clean == "" { return "account" } - re := regexp.MustCompile(`[^a-z0-9]+`) - clean = re.ReplaceAllString(clean, "_") + clean = accountPathSlugPattern.ReplaceAllString(clean, "_") clean = strings.Trim(clean, "_") if clean == "" { return "account" @@ -99,7 +110,7 @@ func pathLooksAbsoluteAnyOS(value string) bool { if filepath.IsAbs(clean) { return true } - if matched, _ := regexp.MatchString(`^[A-Za-z]:[\\/].*`, clean); matched { + if windowsAbsolutePathPattern.MatchString(clean) { return true } if strings.HasPrefix(clean, `\\`) { @@ -119,7 +130,7 @@ func isForeignAbsolutePath(value string) bool { if runtime.GOOS == "windows" { return strings.HasPrefix(clean, "/") } - if matched, _ := regexp.MatchString(`^[A-Za-z]:[\\/].*`, clean); matched { + if windowsAbsolutePathPattern.MatchString(clean) { return true } if strings.HasPrefix(clean, `\\`) { @@ -150,7 +161,7 @@ func (cfg AppConfig) FindAccount(email string) (NotionAccount, int, bool) { return NotionAccount{}, -1, false } for i, account := range cfg.Accounts { - if canonicalEmailKey(account.Email) == target { + if getAccountEmailKey(account) == target { return account, i, true } } @@ -215,6 +226,7 @@ func (helper ResolvedLoginHelper) ProbePath(profileDir string) string { } func ensureAccountPaths(cfg AppConfig, account NotionAccount) NotionAccount { + account.emailKey = canonicalEmailKey(account.Email) helper := cfg.ResolveLoginHelper() if strings.TrimSpace(account.ProfileDir) == "" || 
isForeignAbsolutePath(account.ProfileDir) { account.ProfileDir = helper.ProfileDirFor(account.Email) @@ -328,12 +340,13 @@ func (cfg *AppConfig) UpsertAccount(account NotionAccount) (NotionAccount, int) } func (cfg *AppConfig) DeleteAccount(email string) bool { - _, index, ok := cfg.FindAccount(email) + target := canonicalEmailKey(email) + _, index, ok := cfg.FindAccount(target) if !ok { return false } cfg.Accounts = append(cfg.Accounts[:index], cfg.Accounts[index+1:]...) - if canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(email) { + if canonicalEmailKey(cfg.ActiveAccount) == target { cfg.ActiveAccount = "" cfg.ProbeJSON = "" } diff --git a/internal/app/admin.go b/internal/app/admin.go index 6888344..19f8ea0 100644 --- a/internal/app/admin.go +++ b/internal/app/admin.go @@ -426,9 +426,9 @@ func (a *App) handleAdminLogin(w http.ResponseWriter, r *http.Request) { }) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } if !securePasswordEqual(password, stringValue(payload["password"])) { @@ -671,9 +671,9 @@ func (a *App) handleAdminTest(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } cfg, _, registry := a.State.Snapshot() diff --git a/internal/app/admin_accounts.go b/internal/app/admin_accounts.go index 279e811..32a9127 100644 --- a/internal/app/admin_accounts.go +++ b/internal/app/admin_accounts.go @@ -87,7 +87,7 @@ func (a *App) accountRuntimeSummary(cfg AppConfig, account NotionAccount) map[st "consecutive_failures": account.ConsecutiveFailures, "total_successes": account.TotalSuccesses, "total_failures": 
account.TotalFailures, - "active": canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(account.Email), + "active": canonicalEmailKey(cfg.ActiveAccount) == getAccountEmailKey(account), } if status, err := readLoginStatusFile(account.PendingStatePath); err == nil { item["login_status"] = status @@ -262,9 +262,9 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { case http.MethodGet: writeJSON(w, http.StatusOK, a.buildAccountsPayload()) case http.MethodPost: - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } account, makeActive, err := decodeAccountPayload(payload) @@ -286,11 +286,12 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) case http.MethodPut: - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := accountEmailFromPayload(payload) @@ -314,7 +315,7 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { return } cfg.Accounts[index] = ensureAccountPaths(cfg, next) - if canonicalEmailKey(cfg.ActiveAccount) == canonicalEmailKey(next.Email) && next.Disabled { + if canonicalEmailKey(cfg.ActiveAccount) == getAccountEmailKey(next) && next.Disabled { cfg.ActiveAccount = "" cfg.ProbeJSON = "" } @@ -330,6 +331,7 @@ func (a *App) handleAdminAccounts(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) default: writeJSON(w, http.StatusMethodNotAllowed, 
map[string]any{"detail": "method not allowed"}) @@ -362,6 +364,7 @@ func (a *App) handleAdminAccountDelete(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) } @@ -373,9 +376,9 @@ func (a *App) handleAdminAccountsActivate(w http.ResponseWriter, r *http.Request writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -400,6 +403,7 @@ func (a *App) handleAdminAccountsActivate(w http.ResponseWriter, r *http.Request writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, a.buildAccountsPayload()) } @@ -411,9 +415,9 @@ func (a *App) handleAdminAccountsTest(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } cfg, _, registry := a.State.Snapshot() @@ -676,9 +680,9 @@ func (a *App) handleAdminAccountManualImport(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } req, err := decodeManualImportRequest(payload) @@ -750,6 +754,7 @@ func (a *App) 
handleAdminAccountManualImport(w http.ResponseWriter, r *http.Requ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() cfg, _, _ = a.State.Snapshot() account, _, _ = cfg.FindAccount(accountEmail) writeJSON(w, http.StatusOK, map[string]any{ @@ -767,9 +772,9 @@ func (a *App) handleAdminAccountLoginStart(w http.ResponseWriter, r *http.Reques writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -819,6 +824,7 @@ func (a *App) handleAdminAccountLoginStart(w http.ResponseWriter, r *http.Reques writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, map[string]any{ "success": true, "account": a.accountRuntimeSummary(cfg, account), @@ -834,9 +840,9 @@ func (a *App) handleAdminAccountLoginVerify(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusMethodNotAllowed, map[string]any{"detail": "method not allowed"}) return } - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + writeInvalidBodyError(w, err) return } email := strings.TrimSpace(stringValue(payload["email"])) @@ -893,6 +899,7 @@ func (a *App) handleAdminAccountLoginVerify(w http.ResponseWriter, r *http.Reque writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) return } + a.invalidateDispatchProbeCache() writeJSON(w, http.StatusOK, map[string]any{ "success": true, "account": a.accountRuntimeSummary(cfg, account), diff --git a/internal/app/admin_conversations.go b/internal/app/admin_conversations.go index 
fe1721c..da94b3f 100644 --- a/internal/app/admin_conversations.go +++ b/internal/app/admin_conversations.go @@ -334,7 +334,7 @@ func (a *App) handleAdminEvents(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Access-Control-Allow-Origin", "*") + applyCORSHeaders(w) w.WriteHeader(http.StatusOK) subID, events := a.State.conversations().Subscribe() diff --git a/internal/app/config.go b/internal/app/config.go index d54b314..7c9eeec 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -22,6 +22,7 @@ type FeatureConfig struct { UseReadOnlyMode bool `json:"use_read_only_mode"` ForceDisableUpstreamEdits bool `json:"force_disable_upstream_edits"` ForceFreshThreadPerRequest bool `json:"force_fresh_thread_per_request"` + UseSurfHelperTransport bool `json:"use_surf_helper_transport,omitempty"` WriterMode bool `json:"writer_mode"` EnableGenerateImage bool `json:"enable_generate_image"` EnableCsvAttachmentSupport bool `json:"enable_csv_attachment_support"` @@ -47,6 +48,19 @@ type SessionRefreshConfig struct { AutoSwitch bool `json:"auto_switch_account"` } +type DispatchConfig struct { + ProbeCacheTTLSeconds int `json:"probe_cache_ttl_seconds,omitempty"` +} + +type BrowserConfig struct { + HelperPoolSize int `json:"helper_pool_size,omitempty"` +} + +type DebugConfig struct { + PprofEnabled bool `json:"pprof_enabled"` + PprofAddr string `json:"pprof_addr,omitempty"` +} + type StorageConfig struct { SQLitePath string `json:"sqlite_path,omitempty"` PersistConversations bool `json:"persist_conversations"` @@ -56,6 +70,10 @@ type StorageConfig struct { PersistSillyTavernBindings *bool `json:"persist_sillytavern_bindings,omitempty"` } +type LimitsConfig struct { + MaxRequestBodyBytes int64 `json:"max_request_body_bytes,omitempty"` +} + type PromptConfig struct { Profile string `json:"profile,omitempty"` CustomPrefix 
string `json:"custom_prefix,omitempty"` @@ -67,10 +85,12 @@ type PromptConfig struct { CodingRetryPrefixes []string `json:"coding_retry_prefixes,omitempty"` GeneralRetryPrefixes []string `json:"general_retry_prefixes,omitempty"` DirectAnswerRetryPrefixes []string `json:"direct_answer_retry_prefixes,omitempty"` + precomputedAllRetryPrefixes []string `json:"-"` } type NotionAccount struct { Email string `json:"email"` + emailKey string `json:"-"` ProbeJSON string `json:"probe_json,omitempty"` ProfileDir string `json:"profile_dir,omitempty"` StorageStatePath string `json:"storage_state_path,omitempty"` @@ -153,10 +173,14 @@ type AppConfig struct { Admin AdminConfig `json:"admin"` Responses ResponsesConfig `json:"responses"` Storage StorageConfig `json:"storage"` + Limits LimitsConfig `json:"limits,omitempty"` Prompt PromptConfig `json:"prompt"` Features FeatureConfig `json:"features"` LoginHelper LoginHelperConfig `json:"login_helper"` SessionRefresh SessionRefreshConfig `json:"session_refresh"` + Dispatch DispatchConfig `json:"dispatch"` + Browser BrowserConfig `json:"browser,omitempty"` + Debug DebugConfig `json:"debug"` Accounts []NotionAccount `json:"accounts,omitempty"` Models []ModelDefinition `json:"models,omitempty"` ModelAliases map[string]string `json:"model_aliases,omitempty"` @@ -418,6 +442,9 @@ func defaultConfig() AppConfig { Storage: StorageConfig{ PersistConversations: true, }, + Limits: LimitsConfig{ + MaxRequestBodyBytes: 4 * 1024 * 1024, + }, Prompt: PromptConfig{ Profile: "cognitive_reframing", FallbackProfiles: []string{"toolbox_capability_expansion"}, @@ -440,11 +467,19 @@ func defaultConfig() AppConfig { RetryOnAuthError: true, AutoSwitch: true, }, + Dispatch: DispatchConfig{ + ProbeCacheTTLSeconds: 45, + }, + Debug: DebugConfig{ + PprofEnabled: false, + PprofAddr: "127.0.0.1:6060", + }, Features: FeatureConfig{ UseWebSearch: true, UseReadOnlyMode: false, ForceDisableUpstreamEdits: false, ForceFreshThreadPerRequest: false, + 
UseSurfHelperTransport: false, WriterMode: false, EnableGenerateImage: true, EnableCsvAttachmentSupport: true, @@ -500,6 +535,10 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.PollMaxRounds <= 0 { cfg.PollMaxRounds = 40 } + cfg.Debug.PprofAddr = strings.TrimSpace(cfg.Debug.PprofAddr) + if cfg.Debug.PprofAddr == "" { + cfg.Debug.PprofAddr = "127.0.0.1:6060" + } if cfg.StreamChunkRunes <= 0 { cfg.StreamChunkRunes = 24 } @@ -512,6 +551,9 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.Responses.StoreTTLSeconds <= 0 { cfg.Responses.StoreTTLSeconds = 3600 } + if cfg.Limits.MaxRequestBodyBytes <= 0 { + cfg.Limits.MaxRequestBodyBytes = 4 * 1024 * 1024 + } cfg.Prompt.Profile = strings.TrimSpace(cfg.Prompt.Profile) if cfg.Prompt.Profile == "" { cfg.Prompt.Profile = "cognitive_reframing" @@ -535,6 +577,7 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.Prompt.CodingRetryPrefixes = normalizePromptTextList(cfg.Prompt.CodingRetryPrefixes) cfg.Prompt.GeneralRetryPrefixes = normalizePromptTextList(cfg.Prompt.GeneralRetryPrefixes) cfg.Prompt.DirectAnswerRetryPrefixes = normalizePromptTextList(cfg.Prompt.DirectAnswerRetryPrefixes) + cfg.Prompt.precomputedAllRetryPrefixes = buildPromptGuardAllRetryPrefixes(cfg.Prompt) cfg.Storage.SQLitePath = strings.TrimSpace(cfg.Storage.SQLitePath) if cfg.Storage.SQLitePath == "" && strings.TrimSpace(cfg.ConfigPath) != "" { cfg.Storage.SQLitePath = "data/notion2api.sqlite" @@ -548,6 +591,15 @@ func normalizeConfig(cfg AppConfig) AppConfig { if cfg.SessionRefresh.IntervalSec <= 0 { cfg.SessionRefresh.IntervalSec = 900 } + if cfg.Dispatch.ProbeCacheTTLSeconds < 0 { + cfg.Dispatch.ProbeCacheTTLSeconds = 0 + } + if cfg.Browser.HelperPoolSize < 0 { + cfg.Browser.HelperPoolSize = 0 + } + if cfg.Browser.HelperPoolSize > 8 { + cfg.Browser.HelperPoolSize = 8 + } cfg.Features.SearchScopes = normalizeStringList(cfg.Features.SearchScopes) cfg.Features.AISurface = strings.TrimSpace(cfg.Features.AISurface) if cfg.Features.AISurface == 
"" { @@ -566,6 +618,7 @@ func normalizeConfig(cfg AppConfig) AppConfig { cfg.ActiveAccount = strings.TrimSpace(cfg.ActiveAccount) for i := range cfg.Accounts { cfg.Accounts[i].Email = strings.TrimSpace(cfg.Accounts[i].Email) + cfg.Accounts[i].emailKey = canonicalEmailKey(cfg.Accounts[i].Email) cfg.Accounts[i].ProbeJSON = strings.TrimSpace(cfg.Accounts[i].ProbeJSON) cfg.Accounts[i].ProfileDir = strings.TrimSpace(cfg.Accounts[i].ProfileDir) cfg.Accounts[i].StorageStatePath = strings.TrimSpace(cfg.Accounts[i].StorageStatePath) @@ -797,6 +850,9 @@ func parseCLI() AppConfig { timeoutSec := flag.Int("timeout-sec", 0, "request timeout sec") pollIntervalSec := flag.Float64("poll-interval-sec", 0, "poll interval sec") pollMaxRounds := flag.Int("poll-max-rounds", 0, "poll max rounds") + pprofEnabled := flag.Bool("pprof-enabled", false, "enable pprof debug server") + pprofAddr := flag.String("pprof-addr", "", "pprof listen address") + maxRequestBodyBytes := flag.Int64("max-request-body-bytes", 0, "max request body size in bytes for JSON API endpoints") userName := flag.String("user-name", "", "override user name") spaceName := flag.String("space-name", "", "override space name") flag.Parse() @@ -870,6 +926,15 @@ func parseCLI() AppConfig { if *pollMaxRounds > 0 { cfg.PollMaxRounds = *pollMaxRounds } + if *pprofEnabled { + cfg.Debug.PprofEnabled = true + } + if strings.TrimSpace(*pprofAddr) != "" { + cfg.Debug.PprofAddr = strings.TrimSpace(*pprofAddr) + } + if *maxRequestBodyBytes > 0 { + cfg.Limits.MaxRequestBodyBytes = *maxRequestBodyBytes + } if strings.TrimSpace(*userName) != "" { cfg.UserName = *userName } diff --git a/internal/app/conversations.go b/internal/app/conversations.go index affb6d7..767a529 100644 --- a/internal/app/conversations.go +++ b/internal/app/conversations.go @@ -58,6 +58,7 @@ type ConversationEntry struct { InputAttachments []ConversationAttachment `json:"input_attachments,omitempty"` OutputAttachments []UploadedAttachment 
`json:"output_attachments,omitempty"` Messages []ConversationMessage `json:"messages,omitempty"` + cachedPreview string `json:"-"` } type ConversationSummary struct { @@ -133,6 +134,7 @@ func newConversationStoreFromEntries(entries []ConversationEntry) *ConversationS store := newConversationStore() for _, entry := range entries { cloned := cloneConversationEntry(&entry) + refreshConversationDerivedFields(&cloned) store.items[cloned.ID] = &cloned store.order = append(store.order, cloned.ID) } @@ -285,17 +287,17 @@ func cloneConversationEntry(entry *ConversationEntry) ConversationEntry { return out } +func copyConversationEntryValue(entry *ConversationEntry) ConversationEntry { + if entry == nil { + return ConversationEntry{} + } + return *entry +} + func buildConversationSummary(entry *ConversationEntry) ConversationSummary { - preview := "" - for i := len(entry.Messages) - 1; i >= 0; i-- { - text := collapseWhitespace(entry.Messages[i].Content) - if text == "" && len(entry.Messages[i].Attachments) > 0 { - text = fmt.Sprintf("%d attachments", len(entry.Messages[i].Attachments)) - } - if text != "" { - preview = truncateRunes(text, 96) - break - } + preview := entry.cachedPreview + if preview == "" && len(entry.Messages) > 0 { + preview = conversationPreviewFromMessages(entry.Messages) } return ConversationSummary{ ID: entry.ID, @@ -327,6 +329,26 @@ func buildConversationSummary(entry *ConversationEntry) ConversationSummary { } } +func conversationPreviewFromMessages(messages []ConversationMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + text := collapseWhitespace(messages[i].Content) + if text == "" && len(messages[i].Attachments) > 0 { + text = fmt.Sprintf("%d attachments", len(messages[i].Attachments)) + } + if text != "" { + return truncateRunes(text, 96) + } + } + return "" +} + +func refreshConversationDerivedFields(entry *ConversationEntry) { + if entry == nil { + return + } + entry.cachedPreview = conversationPreviewFromMessages(entry.Messages) 
+} + func conversationMessageSegments(entry *ConversationEntry) []conversationPromptSegment { if entry == nil || len(entry.Messages) == 0 { return nil @@ -410,7 +432,7 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn if id == "" { id = "conv_" + strings.ReplaceAll(randomUUID(), "-", "") } - entry := &ConversationEntry{ + entry := ConversationEntry{ ID: id, Title: conversationTitle(req.Prompt, req.InputAttachments), Origin: "local", @@ -440,17 +462,19 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn Attachments: cloneConversationAttachments(entry.InputAttachments), }) } + refreshConversationDerivedFields(&entry) s.mu.Lock() if s.items[id] != nil { id = "conv_" + strings.ReplaceAll(randomUUID(), "-", "") entry.ID = id } - s.items[id] = entry + entryPtr := &entry + s.items[id] = entryPtr s.order = append([]string{id}, s.order...) s.trimLocked() - cloned := cloneConversationEntry(entry) - summary := buildConversationSummary(entry) + cloned := copyConversationEntryValue(entryPtr) + summary := buildConversationSummary(entryPtr) s.mu.Unlock() s.broadcast(ConversationEvent{ @@ -458,7 +482,7 @@ func (s *ConversationStore) Create(req ConversationCreateRequest) ConversationEn ConversationID: id, At: now, Summary: &summary, - Conversation: &cloned, + Conversation: entryPtr, }) return cloned } @@ -469,39 +493,41 @@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Source = firstNonEmpty(req.Source, entry.Source) - entry.Transport = firstNonEmpty(req.Transport, entry.Transport) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Source = firstNonEmpty(req.Source, next.Source) + next.Transport = firstNonEmpty(req.Transport, next.Transport) if req.Ephemeral { - 
entry.Ephemeral = true - entry.EphemeralReason = firstNonEmpty(strings.TrimSpace(req.EphemeralReason), entry.EphemeralReason) + next.Ephemeral = true + next.EphemeralReason = firstNonEmpty(strings.TrimSpace(req.EphemeralReason), next.EphemeralReason) if !req.AutoDeleteAt.IsZero() { - entry.AutoDeleteAt = timePointer(req.AutoDeleteAt) + next.AutoDeleteAt = timePointer(req.AutoDeleteAt) } } if clean := strings.TrimSpace(req.Model); clean != "" { - entry.Model = clean + next.Model = clean } if clean := strings.TrimSpace(req.NotionModel); clean != "" { - entry.NotionModel = clean + next.NotionModel = clean } - entry.UseWebSearch = req.UseWebSearch - entry.Status = "running" - entry.Error = "" - entry.InputAttachments = cloneConversationAttachments(req.InputAttachments) - entry.UpdatedAt = now - if len(entry.Messages) > 0 { - last := &entry.Messages[len(entry.Messages)-1] + next.UseWebSearch = req.UseWebSearch + next.Status = "running" + next.Error = "" + next.InputAttachments = cloneConversationAttachments(req.InputAttachments) + next.UpdatedAt = now + if len(next.Messages) > 0 { + last := &next.Messages[len(next.Messages)-1] if last.Role == "assistant" && last.Status != "completed" { last.Status = "failed" last.UpdatedAt = now } } if strings.TrimSpace(req.Prompt) != "" || len(req.InputAttachments) > 0 { - entry.Messages = append(entry.Messages, ConversationMessage{ + next.Messages = append(next.Messages, ConversationMessage{ ID: "msg_user_" + strings.ReplaceAll(randomUUID(), "-", ""), Role: "user", Status: "completed", @@ -511,8 +537,11 @@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea Attachments: cloneConversationAttachments(req.InputAttachments), }) } + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) + cloned = copyConversationEntryValue(entry) summary = buildConversationSummary(entry) ok = true } @@ -525,7 +554,7 
@@ func (s *ConversationStore) Continue(conversationID string, req ConversationCrea ConversationID: conversationID, At: now, Summary: &summary, - Conversation: &cloned, + Conversation: entry, }) return cloned, nil } @@ -553,17 +582,22 @@ func (s *ConversationStore) SetEnvelopeIDs(conversationID string, responseID str var ( summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) if strings.TrimSpace(responseID) != "" { - entry.ResponseID = strings.TrimSpace(responseID) + next.ResponseID = strings.TrimSpace(responseID) } if strings.TrimSpace(completionID) != "" { - entry.CompletionID = strings.TrimSpace(completionID) + next.CompletionID = strings.TrimSpace(completionID) } - entry.UpdatedAt = now + next.UpdatedAt = now + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) summary = buildConversationSummary(entry) ok = true @@ -575,6 +609,7 @@ func (s *ConversationStore) SetEnvelopeIDs(conversationID string, responseID str ConversationID: conversationID, At: now, Summary: &summary, + Conversation: entry, }) } } @@ -587,21 +622,26 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st now := time.Now().UTC() var ( summary ConversationSummary - msg ConversationMessage + msg *ConversationMessage ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - assistant := s.ensureAssistantMessageLocked(entry, now) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + assistant := s.ensureAssistantMessageLocked(&next, now) assistant.Content += delta assistant.Status = "streaming" assistant.UpdatedAt = now - entry.Status = "running" - entry.UpdatedAt = now + next.Status = "running" + next.UpdatedAt = now + 
refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) summary = buildConversationSummary(entry) - msg = cloneConversationMessage(*assistant) + msg = assistant ok = true } s.mu.Unlock() @@ -612,7 +652,8 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st At: now, Delta: delta, Summary: &summary, - Message: &msg, + Conversation: entry, + Message: msg, }) } } @@ -620,46 +661,49 @@ func (s *ConversationStore) AppendAssistantDelta(conversationID string, delta st func (s *ConversationStore) Complete(conversationID string, result InferenceResult) { now := time.Now().UTC() var ( - cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Status = "completed" - entry.UpdatedAt = now - if entry.Ephemeral { - entry.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Status = "completed" + next.UpdatedAt = now + if next.Ephemeral { + next.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) } - entry.ThreadID = strings.TrimSpace(result.ThreadID) - entry.TraceID = strings.TrimSpace(result.TraceID) - entry.MessageID = strings.TrimSpace(result.MessageID) - entry.AccountEmail = strings.TrimSpace(result.AccountEmail) - entry.Error = "" - entry.OutputAttachments = cloneUploadedAttachments(result.Attachments) - assistant := s.ensureAssistantMessageLocked(entry, now) + next.ThreadID = strings.TrimSpace(result.ThreadID) + next.TraceID = strings.TrimSpace(result.TraceID) + next.MessageID = strings.TrimSpace(result.MessageID) + next.AccountEmail = strings.TrimSpace(result.AccountEmail) + next.Error = "" + next.OutputAttachments = cloneUploadedAttachments(result.Attachments) + assistant := s.ensureAssistantMessageLocked(&next, now) 
assistant.Status = "completed" assistant.Content = sanitizeAssistantVisibleText(result.Text) assistant.Attachments = summarizeUploadedAttachments(result.Attachments) assistant.UpdatedAt = now - if len(entry.Messages) > 0 { - entry.Messages[len(entry.Messages)-1] = cloneConversationMessage(*assistant) + if len(next.Messages) > 0 { + next.Messages[len(next.Messages)-1] = cloneConversationMessage(*assistant) } + refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) summary = buildConversationSummary(entry) ok = true } s.mu.Unlock() if ok { s.broadcast(ConversationEvent{ - Type: "conversation.completed", - ConversationID: conversationID, - At: now, - Summary: &summary, - Conversation: &cloned, - }) + Type: "conversation.completed", + ConversationID: conversationID, + At: now, + Summary: &summary, + Conversation: entry, + }) } } @@ -670,41 +714,44 @@ func (s *ConversationStore) Fail(conversationID string, err error) { now := time.Now().UTC() message := strings.TrimSpace(err.Error()) var ( - cloned ConversationEntry summary ConversationSummary ok bool + entry *ConversationEntry ) s.mu.Lock() - entry := s.items[conversationID] - if entry != nil { - entry.Status = "failed" - entry.Error = message - entry.UpdatedAt = now - if entry.Ephemeral { - entry.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) + current := s.items[conversationID] + if current != nil { + next := cloneConversationEntry(current) + next.Status = "failed" + next.Error = message + next.UpdatedAt = now + if next.Ephemeral { + next.AutoDeleteAt = timePointer(now.Add(sillyTavernQuietConversationTTL)) } - if len(entry.Messages) > 0 { - last := &entry.Messages[len(entry.Messages)-1] + if len(next.Messages) > 0 { + last := &next.Messages[len(next.Messages)-1] if last.Role == "assistant" && last.Status != "completed" { last.Status = "failed" last.UpdatedAt = now } } + 
refreshConversationDerivedFields(&next) + entry = &next + s.items[conversationID] = entry s.moveToFrontLocked(conversationID) - cloned = cloneConversationEntry(entry) summary = buildConversationSummary(entry) ok = true } s.mu.Unlock() if ok { s.broadcast(ConversationEvent{ - Type: "conversation.failed", - ConversationID: conversationID, - At: now, - Error: message, - Summary: &summary, - Conversation: &cloned, - }) + Type: "conversation.failed", + ConversationID: conversationID, + At: now, + Error: message, + Summary: &summary, + Conversation: entry, + }) } } @@ -761,7 +808,7 @@ func (s *ConversationStore) ListExpiredEphemeral(now time.Time, limit int) []Con if entry.AutoDeleteAt == nil || entry.AutoDeleteAt.After(now) { continue } - items = append(items, cloneConversationEntry(entry)) + items = append(items, copyConversationEntryValue(entry)) if len(items) >= limit { break } @@ -776,8 +823,7 @@ func (s *ConversationStore) Get(conversationID string) (ConversationEntry, bool) if entry == nil { return ConversationEntry{}, false } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } func (s *ConversationStore) FindByThreadID(threadID string) (ConversationEntry, bool) { @@ -795,8 +841,7 @@ func (s *ConversationStore) FindByThreadID(threadID string) (ConversationEntry, if strings.TrimSpace(entry.ThreadID) != threadID { continue } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } return ConversationEntry{}, false } @@ -820,8 +865,7 @@ func (s *ConversationStore) FindContinuationBySegments(history []conversationPro if !conversationSegmentsMatchSuffix(entrySegments, normalizedHistory) { continue } - cloned := cloneConversationEntry(entry) - return cloned, true + return copyConversationEntryValue(entry), true } return ConversationEntry{}, false } @@ -881,16 +925,18 @@ func (s *ServerState) deleteResponsesByConversationOrThread(conversationID strin 
return } s.mu.Lock() - for id, item := range s.ResponsesByID { - if (conversationID != "" && strings.TrimSpace(item.ConversationID) == conversationID) || - (threadID != "" && strings.TrimSpace(item.ThreadID) == threadID) { - delete(s.ResponsesByID, id) - } + if s.ResponseStore != nil { + s.ResponseStore.deleteByConversationOrThread(conversationID, threadID) } + sqliteWriter := s.sqliteWriter store := s.Store storeEnabled := store != nil && responsesPersistenceEnabled(s.Config) s.mu.Unlock() if storeEnabled { + if sqliteWriter != nil { + sqliteWriter.EnqueueDeleteResponsesByConversationOrThread(conversationID, threadID) + return + } if err := store.DeleteResponsesByConversationOrThread(conversationID, threadID); err != nil { log.Printf("[sqlite] delete responses conversation=%s thread=%s failed: %v", conversationID, threadID, err) } diff --git a/internal/app/httpclient_audit.md b/internal/app/httpclient_audit.md new file mode 100644 index 0000000..34880ef --- /dev/null +++ b/internal/app/httpclient_audit.md @@ -0,0 +1,69 @@ +# T-4-2 HTTP Client / Transport Audit + +Date: 2026-05-02 + +## Scope checked + +- `internal/app/notion_client.go` +- `internal/app/login_helper.go` +- `internal/app/notion_client_login_transport.go` +- `internal/app/account_discovery.go` +- `internal/app/session_refresh.go` + +## Findings + +### 1) NotionAI request path creates new `http.Client`/`Transport` per `NotionAIClient` + +- Location: `internal/app/notion_client.go:newNotionAIClientWithMode` +- Behavior: + - Builds a fresh `http.Transport` and `http.Client` every time a `NotionAIClient` is created. + - In dispatch paths (`runPromptWithSession*`) this can happen frequently, so connection pools are not reused across those client instances. +- Impact: + - Potentially higher connect/TLS handshake overhead under sustained traffic. + - Extra pressure on upstream and local sockets due to fragmented pools. 
+ +### 2) Login helper path also creates fresh `http.Client` + +- Location: `internal/app/login_helper.go:newNotionLoginSession` +- Behavior: + - Creates a new cookie jar and `http.Client` per login session call. +- Notes: + - This path is less hot than inference path, but still relevant for repeated refresh/login workflows. + +### 3) Proxy/header behavior correctness constraints + +- Proxy resolution and resin headers are request/account dependent: + - `ProxyResolver.ResolveProxyForRequest(accountEmail, targetURL)` can vary by account/policy. + - `postJSONResponse` overlays per-request proxy headers (e.g. resin account header). +- Any reuse strategy must preserve: + - account-aware proxy resolution + - per-request header injection behavior + - stream vs non-stream timeout difference + +## Recommendation + +Introduce a transport cache in `internal/app/notion_client.go`: + +- Cache key dimensions: + - normalized upstream base/origin/host/tls server name + - proxy mode + proxy urls + resin settings + - account email key (for account-specific proxy routing) + - streaming flag is **not** required in transport key (timeout is on `http.Client`, not transport) +- Cache value: + - reusable `*http.Transport` +- Then construct short-lived `http.Client` wrappers over cached transport: + - standard client timeout = request timeout + - streaming client timeout = 0 +- Add evidence: + - metric for transport/client creation count + - benchmark around repeated client creation path if needed + +## Current status + +- Updated 2026-05-02 follow-up: + - Implemented transport cache in `newNotionAIClientWithMode` via keyed map + RWMutex. + - Added runtime visibility metric: `notion2api_http_transport_cache_total` (`hit_rlock`, `hit_lock`, `miss_new`) exposed through `/debug/vars`. 
+ - Added tests validating: + - same account/config => transport reuse + - different account proxy policy => transport separation + - Added benchmark `BenchmarkNewNotionAIClientWithModeTransportCache` showing warm-cache path lower alloc/op and ns/op than forced cold-cache path. diff --git a/internal/app/login_helper.go b/internal/app/login_helper.go index b2e3a02..2d967a5 100644 --- a/internal/app/login_helper.go +++ b/internal/app/login_helper.go @@ -162,17 +162,18 @@ func writeLoginStorageState(path string, payload loginStorageState) error { return writePrettyJSONFile(path, payload) } -func newNotionLoginSession(timeout time.Duration, upstream NotionUpstream, resolver *ProxyResolver, accountEmail string) (*loginHTTPSession, error) { +func newNotionLoginSession(timeout time.Duration, upstream NotionUpstream, resolver *ProxyResolver, accountEmail string, cfg AppConfig) (*loginHTTPSession, error) { jar, err := cookiejar.New(nil) if err != nil { return nil, err } return &loginHTTPSession{ - Client: &http.Client{Timeout: timeout, Jar: jar}, - ProxyResolver: resolver, - AccountEmail: strings.TrimSpace(accountEmail), - Timeout: timeout, - Upstream: upstream, + Client: &http.Client{Timeout: timeout, Jar: jar}, + ProxyResolver: resolver, + AccountEmail: strings.TrimSpace(accountEmail), + Timeout: timeout, + Upstream: upstream, + UseSurfHelperTransport: cfg.Features.UseSurfHelperTransport, }, nil } @@ -348,7 +349,7 @@ func fetchLoginBootstrap(ctx context.Context, session *loginHTTPSession, upstrea "accept-language": "zh-CN,zh;q=0.9", "user-agent": notionLoginUA, } - status, respHeaders, body, err := loginWreqDoRequest(ctx, session, http.MethodGet, upstream.LoginURL(), headers, nil) + status, respHeaders, body, err := loginTransportDoRequest(ctx, session, http.MethodGet, upstream.LoginURL(), headers, nil) if err != nil { return loginBootstrap{}, err } @@ -395,7 +396,7 @@ func postNotionLoginJSON(ctx context.Context, session *loginHTTPSession, upstrea 
"notion-audit-log-platform": "web", "x-notion-active-user-header": strings.TrimSpace(activeUserID), } - status, respHeaders, respBody, err := loginWreqDoRequest(ctx, session, http.MethodPost, targetURL, headers, body) + status, respHeaders, respBody, err := loginTransportDoRequest(ctx, session, http.MethodPost, targetURL, headers, body) if err != nil { return nil, err } @@ -613,7 +614,7 @@ func StartEmailLogin(ctx context.Context, cfg AppConfig, req LoginStartRequest) upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email), cfg) if err != nil { return failLoginState(req.PendingPath, state, err) } @@ -682,7 +683,7 @@ func VerifyEmailLogin(ctx context.Context, cfg AppConfig, req LoginVerifyRequest upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email)) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, firstNonEmpty(req.AccountEmail, req.Email), cfg) if err != nil { return failLoginState(req.PendingPath, pending, err) } diff --git a/internal/app/main.go b/internal/app/main.go index 5e96d4a..278b4bc 100644 --- a/internal/app/main.go +++ b/internal/app/main.go @@ -1,14 +1,20 @@ package app import ( + "bytes" "context" "encoding/json" + "errors" + "expvar" "fmt" + "io" "log" "net/http" + _ "net/http/pprof" "runtime/debug" "strings" "sync" + "sync/atomic" "time" ) @@ -20,21 +26,35 @@ type StoredResponse struct { AccountEmail string } +type snapshotBundle struct { + Config AppConfig + Session SessionInfo + ModelRegistry ModelRegistry + DispatchOrder []NotionAccount +} + type ServerState struct { - mu sync.RWMutex - refreshMu sync.Mutex - Config AppConfig - 
Session SessionInfo - Client *NotionAIClient - Store *SQLiteStore - ModelRegistry ModelRegistry - ResponsesByID map[string]StoredResponse - Conversations *ConversationStore - AdminTokens map[string]time.Time - AdminLoginAttempts map[string]AdminLoginAttempt - AccountDispatchSlots map[string]accountDispatchState - LastSessionRefresh time.Time - LastSessionRefreshError string + mu sync.RWMutex + refreshMu sync.Mutex + Config AppConfig + Session SessionInfo + Client *NotionAIClient + Store *SQLiteStore + ModelRegistry ModelRegistry + ResponseStore *responseStore + Conversations *ConversationStore + AdminTokens map[string]time.Time + AdminLoginAttempts map[string]AdminLoginAttempt + DispatchProbeCache *probeCache + LastSessionRefresh time.Time + LastSessionRefreshError string + responseStoreCleanupCancel context.CancelFunc + sqliteWriter *SQLiteWriter + snap atomic.Pointer[snapshotBundle] + slots atomic.Pointer[map[string]*accountSlot] + cachedHealthzStaticJSON atomic.Pointer[[]byte] + cachedModelsListJSON atomic.Pointer[[]byte] + cachedModelByIDJSON atomic.Pointer[map[string][]byte] } type accountDispatchState struct { @@ -42,6 +62,38 @@ type accountDispatchState struct { InFlight int } +type accountSlot struct { + max atomic.Int32 + inflight atomic.Int32 +} + +type healthzStaticPayload struct { + OK bool `json:"ok"` + DefaultModel string `json:"default_model"` + ModelCount int `json:"model_count"` + UserEmail string `json:"user_email"` + SpaceID string `json:"space_id"` + ActiveAccount string `json:"active_account"` + SessionRefreshEnable bool `json:"session_refresh_enabled"` +} + +type publicModelPayload struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Name string `json:"name"` + Family string `json:"family"` + Group string `json:"group"` + Beta bool `json:"beta"` + NotionModel string `json:"notion_model"` +} + +type publicModelsListPayload struct { + Object string `json:"object"` + 
Data []publicModelPayload `json:"data"` +} + type App struct { State *ServerState runPromptOverride func(*http.Request, PromptRunRequest) (InferenceResult, error) @@ -56,8 +108,15 @@ const ( ephemeralConversationCleanupInterval = time.Minute ephemeralConversationCleanupBatchSize = 24 sillyTavernQuietConversationTTL = 10 * time.Minute + corsAllowOrigin = "*" + corsAllowHeaders = "Authorization, Content-Type, X-Admin-Token" + corsAllowMethods = "GET, POST, PUT, DELETE, OPTIONS" ) +var errRequestTooLarge = errors.New("request body too large") +var responseStorePruneTotalMetric = expvar.NewMap("notion2api_response_store_prune_total") +var testHookResponseStoreCleanupInterval time.Duration + type continuationTarget struct { Conversation ConversationEntry Session *conversationContinuationState @@ -112,28 +171,68 @@ func normalizeAccountMaxConcurrency(raw int) int { return raw } -func (s *ServerState) initializeAccountDispatchSlotsLocked() { - if s.AccountDispatchSlots == nil { - s.AccountDispatchSlots = map[string]accountDispatchState{} +func clampSlotInFlight(slot *accountSlot, max int32) int32 { + if slot == nil { + return 0 + } + if max <= 0 { + max = 1 + } + for { + current := slot.inflight.Load() + if current < 0 { + if slot.inflight.CompareAndSwap(current, 0) { + return 0 + } + continue + } + if current <= max { + return current + } + if slot.inflight.CompareAndSwap(current, max) { + return max + } } - next := map[string]accountDispatchState{} +} + +func (s *ServerState) rebuildAccountSlotsLocked() { + if s == nil { + return + } + var previous map[string]*accountSlot + if loaded := s.slots.Load(); loaded != nil { + previous = *loaded + } + next := make(map[string]*accountSlot, len(s.Config.Accounts)) for _, account := range s.Config.Accounts { - emailKey := canonicalEmailKey(account.Email) + emailKey := getAccountEmailKey(account) if emailKey == "" { continue } - maxConcurrency := normalizeAccountMaxConcurrency(account.MaxConcurrency) - state := 
s.AccountDispatchSlots[emailKey] - state.MaxConcurrency = maxConcurrency - if state.InFlight < 0 { - state.InFlight = 0 - } - if state.InFlight > state.MaxConcurrency { - state.InFlight = state.MaxConcurrency + maxConcurrency := int32(normalizeAccountMaxConcurrency(account.MaxConcurrency)) + if existing := previous[emailKey]; existing != nil { + existing.max.Store(maxConcurrency) + clampSlotInFlight(existing, maxConcurrency) + next[emailKey] = existing + continue } - next[emailKey] = state + slot := &accountSlot{} + slot.max.Store(maxConcurrency) + next[emailKey] = slot + } + s.slots.Store(&next) + syncDispatchSlotInflightFromSlots(next) +} + +func (s *ServerState) loadAccountSlots() map[string]*accountSlot { + if s == nil { + return nil } - s.AccountDispatchSlots = next + loaded := s.slots.Load() + if loaded == nil { + return nil + } + return *loaded } func (s *ServerState) TryAcquireAccountDispatchSlot(email string) bool { @@ -141,19 +240,24 @@ func (s *ServerState) TryAcquireAccountDispatchSlot(email string) bool { if emailKey == "" { return false } - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return false } - if state.InFlight >= state.MaxConcurrency { - return false + for { + maxConcurrency := slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + if inflight >= maxConcurrency { + return false + } + if slot.inflight.CompareAndSwap(inflight, inflight+1) { + setDispatchSlotInflight(emailKey, int(inflight+1)) + return true + } } - state.InFlight++ - s.AccountDispatchSlots[emailKey] = state - return true } func (s *ServerState) ReleaseAccountDispatchSlot(email string) { @@ -161,19 +265,21 @@ func (s *ServerState) ReleaseAccountDispatchSlot(email string) { if emailKey == "" { return } - s.mu.Lock() - defer s.mu.Unlock() - if s.AccountDispatchSlots == nil { - return 
- } - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return } - if state.InFlight > 0 { - state.InFlight-- + for { + inflight := slot.inflight.Load() + if inflight <= 0 { + setDispatchSlotInflight(emailKey, 0) + return + } + if slot.inflight.CompareAndSwap(inflight, inflight-1) { + setDispatchSlotInflight(emailKey, int(inflight-1)) + return + } } - s.AccountDispatchSlots[emailKey] = state } func (s *ServerState) RemainingAccountDispatchSlots(email string) int { @@ -181,24 +287,27 @@ func (s *ServerState) RemainingAccountDispatchSlots(email string) int { if emailKey == "" { return 0 } - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := s.loadAccountSlots()[emailKey] + if slot == nil { return 0 } - remaining := state.MaxConcurrency - state.InFlight + maxConcurrency := slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + remaining := int(maxConcurrency - inflight) if remaining < 0 { - remaining = 0 + return 0 } return remaining } func (s *ServerState) AvailableDispatchCapacity(emails []string) int { - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() + slots := s.loadAccountSlots() + if len(slots) == 0 { + return 0 + } total := 0 seen := map[string]struct{}{} for _, email := range emails { @@ -210,11 +319,16 @@ func (s *ServerState) AvailableDispatchCapacity(emails []string) int { continue } seen[emailKey] = struct{}{} - state, ok := s.AccountDispatchSlots[emailKey] - if !ok { + slot := slots[emailKey] + if slot == nil { continue } - remaining := state.MaxConcurrency - state.InFlight + maxConcurrency := slot.max.Load() + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := slot.inflight.Load() + remaining := int(maxConcurrency - inflight) if remaining > 0 { total += remaining } @@ -223,17 +337,31 @@ func (s 
*ServerState) AvailableDispatchCapacity(emails []string) int { } func (s *ServerState) AccountDispatchSnapshot() map[string]accountDispatchState { - s.mu.Lock() - defer s.mu.Unlock() - s.initializeAccountDispatchSlotsLocked() - out := make(map[string]accountDispatchState, len(s.AccountDispatchSlots)) - for key, value := range s.AccountDispatchSlots { - out[key] = value + slots := s.loadAccountSlots() + out := make(map[string]accountDispatchState, len(slots)) + for key, slot := range slots { + if slot == nil { + continue + } + maxConcurrency := int(slot.max.Load()) + if maxConcurrency <= 0 { + maxConcurrency = 1 + } + inflight := int(slot.inflight.Load()) + if inflight < 0 { + inflight = 0 + } + if inflight > maxConcurrency { + inflight = maxConcurrency + } + out[key] = accountDispatchState{ + MaxConcurrency: maxConcurrency, + InFlight: inflight, + } } return out } - func maxFloat(a float64, b float64) float64 { if a > b { return a @@ -265,13 +393,13 @@ func newServerState(cfg AppConfig) (*ServerState, error) { return nil, err } state := &ServerState{ - ResponsesByID: map[string]StoredResponse{}, Conversations: newConversationStore(), AdminTokens: map[string]time.Time{}, AdminLoginAttempts: map[string]AdminLoginAttempt{}, - AccountDispatchSlots: map[string]accountDispatchState{}, + DispatchProbeCache: newProbeCache(), Store: store, } + state.ResponseStore = newResponseStore(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) persistedAccountsLoaded := false if store != nil { accounts, activeAccount, ok, loadErr := store.LoadAccounts() @@ -305,7 +433,10 @@ func newServerState(cfg AppConfig) (*ServerState, error) { _ = store.Close() return nil, loadErr } - state.ResponsesByID = responses + if state.ResponseStore == nil { + state.ResponseStore = newResponseStore(time.Duration(maxInt(state.Config.Responses.StoreTTLSeconds, 1)) * time.Second) + } + state.ResponseStore.replaceAll(responses) } if conversationSnapshotsPersistenceEnabled(state.Config) { 
conversations, loadErr := store.LoadConversations() @@ -321,7 +452,9 @@ func newServerState(cfg AppConfig) (*ServerState, error) { return nil, saveErr } } + state.sqliteWriter = newSQLiteWriter(store, time.Duration(maxInt(state.Config.Responses.StoreTTLSeconds, 1))*time.Second) } + state.startResponseStoreCleanupLoop(context.Background()) return state, nil } @@ -352,15 +485,42 @@ func (s *ServerState) ApplyConfig(cfg AppConfig) error { s.Session = session s.ModelRegistry = registry s.Client = client + if s.sqliteWriter != nil { + s.sqliteWriter.SetTTL(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } + s.rebuildAccountSlotsLocked() + s.updateSnapshotBundleLocked() + s.rebuildStaticJSONCachesLocked() return nil } func (s *ServerState) Snapshot() (AppConfig, SessionInfo, ModelRegistry) { + if s == nil { + return AppConfig{}, SessionInfo{}, ModelRegistry{} + } + if snap := s.snap.Load(); snap != nil { + return snap.Config, snap.Session, snap.ModelRegistry + } s.mu.RLock() defer s.mu.RUnlock() return s.Config, s.Session, s.ModelRegistry } +func (s *ServerState) updateSnapshotBundleLocked() { + if s == nil { + return + } + now := time.Now() + dispatchOrder := buildDispatchCandidateOrder(s.Config, now) + bundle := &snapshotBundle{ + Config: s.Config, + Session: s.Session, + ModelRegistry: s.ModelRegistry, + DispatchOrder: dispatchOrder, + } + s.snap.Store(bundle) +} + func (s *ServerState) SaveAndApply(cfg AppConfig) error { cfg = normalizeConfig(cfg) if err := validateConfiguredAPIKey(cfg); err != nil { @@ -382,6 +542,17 @@ func (s *ServerState) SaveAndApply(cfg AppConfig) error { return err } } + s.mu.Lock() + if s.ResponseStore == nil { + s.ResponseStore = newResponseStore(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } else { + s.ResponseStore.setTTL(time.Duration(maxInt(cfg.Responses.StoreTTLSeconds, 1)) * time.Second) + } + s.updateSnapshotBundleLocked() + s.mu.Unlock() + if 
canonicalEmailKey(current.ActiveAccount) != canonicalEmailKey(cfg.ActiveAccount) && s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } return nil } @@ -394,16 +565,6 @@ func (s *ServerState) conversationPersistenceStore() *SQLiteStore { return s.Store } -func (s *ServerState) cleanupExpiredResponsesLocked(now time.Time) { - ttlSeconds := maxInt(s.Config.Responses.StoreTTLSeconds, 1) - ttl := time.Duration(ttlSeconds) * time.Second - for id, item := range s.ResponsesByID { - if now.Sub(item.CreatedAt) > ttl { - delete(s.ResponsesByID, id) - } - } -} - func (s *ServerState) saveResponse(responseID string, payload map[string]any, conversationID string, threadID string) { s.saveResponseWithAccount(responseID, payload, conversationID, threadID, "") } @@ -411,24 +572,33 @@ func (s *ServerState) saveResponse(responseID string, payload map[string]any, co func (s *ServerState) saveResponseWithAccount(responseID string, payload map[string]any, conversationID string, threadID string, accountEmail string) { now := time.Now().UTC() s.mu.Lock() - s.cleanupExpiredResponsesLocked(now) - s.ResponsesByID[responseID] = StoredResponse{ + store := s.ResponseStore + if store == nil { + store = newResponseStore(time.Duration(maxInt(s.Config.Responses.StoreTTLSeconds, 1)) * time.Second) + s.ResponseStore = store + } + store.save(responseID, StoredResponse{ Payload: payload, CreatedAt: now, ConversationID: strings.TrimSpace(conversationID), ThreadID: strings.TrimSpace(threadID), AccountEmail: strings.TrimSpace(accountEmail), - } - store := s.Store + }, now) + sqliteWriter := s.sqliteWriter + sqliteStore := s.Store ttl := time.Duration(maxInt(s.Config.Responses.StoreTTLSeconds, 1)) * time.Second - storeEnabled := store != nil && responsesPersistenceEnabled(s.Config) + storeEnabled := sqliteStore != nil && responsesPersistenceEnabled(s.Config) s.mu.Unlock() if storeEnabled { - if err := store.SaveResponse(responseID, payload, now, conversationID, threadID, 
accountEmail); err != nil { + if sqliteWriter != nil { + sqliteWriter.EnqueueSaveResponse(responseID, payload, now, conversationID, threadID, accountEmail) + return + } + if err := sqliteStore.SaveResponse(responseID, payload, now, conversationID, threadID, accountEmail); err != nil { log.Printf("[sqlite] save response %s failed: %v", responseID, err) return } - if err := store.DeleteExpiredResponses(ttl); err != nil { + if err := sqliteStore.DeleteExpiredResponses(ttl); err != nil { log.Printf("[sqlite] cleanup responses failed: %v", err) } } @@ -445,12 +615,10 @@ func (s *ServerState) getResponse(responseID string) (map[string]any, bool) { func (s *ServerState) getStoredResponse(responseID string) (StoredResponse, bool) { s.mu.Lock() defer s.mu.Unlock() - s.cleanupExpiredResponsesLocked(time.Now()) - payload, ok := s.ResponsesByID[responseID] - if !ok { + if s.ResponseStore == nil { return StoredResponse{}, false } - return payload, true + return s.ResponseStore.get(responseID, time.Now().UTC()) } func (s *ServerState) loadConversationContinuationStateByConversationID(conversationID string) (*conversationContinuationState, error) { @@ -548,24 +716,204 @@ func (s *ServerState) invalidateConversationSession(sessionID string, status str func (s *ServerState) Close() error { s.mu.RLock() store := s.Store + cancelCleanup := s.responseStoreCleanupCancel + sqliteWriter := s.sqliteWriter s.mu.RUnlock() + if cancelCleanup != nil { + cancelCleanup() + } + if sqliteWriter != nil { + sqliteWriter.Close() + } if store == nil { return nil } return store.Close() } +func (s *ServerState) startResponseStoreCleanupLoop(parent context.Context) { + if s == nil { + return + } + if parent == nil { + parent = context.Background() + } + interval := responseStoreCleanupInterval + if testHookResponseStoreCleanupInterval > 0 { + interval = testHookResponseStoreCleanupInterval + } + ctx, cancel := context.WithCancel(parent) + s.mu.Lock() + s.responseStoreCleanupCancel = cancel + 
s.mu.Unlock() + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.runResponseStoreCleanupOnce(time.Now().UTC()) + } + } + }() +} + +func (s *ServerState) runResponseStoreCleanupOnce(now time.Time) int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + if s.ResponseStore == nil { + return 0 + } + removed := s.ResponseStore.pruneExpired(now) + if removed > 0 { + responseStorePruneTotalMetric.Add("expired_entries", int64(removed)) + } + return removed +} + +func buildPublicModelPayload(entry ModelDefinition) publicModelPayload { + return publicModelPayload{ + ID: entry.ID, + Object: "model", + Created: 0, + OwnedBy: "notion2api", + Name: entry.Name, + Family: entry.Family, + Group: entry.Group, + Beta: entry.Beta, + NotionModel: entry.NotionModel, + } +} + +func buildPublicModelsListPayload(registry ModelRegistry) publicModelsListPayload { + items := make([]publicModelPayload, 0, len(registry.Entries)) + for _, entry := range registry.Entries { + if !entry.Enabled { + continue + } + items = append(items, buildPublicModelPayload(entry)) + } + return publicModelsListPayload{ + Object: "list", + Data: items, + } +} + +func cloneBytes(src []byte) []byte { + if len(src) == 0 { + return nil + } + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func cloneBytesMap(src map[string][]byte) map[string][]byte { + if len(src) == 0 { + return nil + } + dst := make(map[string][]byte, len(src)) + for key, value := range src { + dst[key] = cloneBytes(value) + } + return dst +} + +func (s *ServerState) rebuildStaticJSONCachesLocked() { + healthPayload := healthzStaticPayload{ + OK: true, + DefaultModel: s.Config.DefaultPublicModel(), + ModelCount: len(s.ModelRegistry.Entries), + UserEmail: s.Session.UserEmail, + SpaceID: s.Session.SpaceID, + ActiveAccount: s.Config.ActiveAccount, + SessionRefreshEnable: s.Config.ResolveSessionRefresh().Enabled, + } + 
healthBody, err := json.Marshal(healthPayload) + if err == nil { + healthBodyCopy := cloneBytes(healthBody) + s.cachedHealthzStaticJSON.Store(&healthBodyCopy) + } else { + s.cachedHealthzStaticJSON.Store(nil) + } + + modelsPayload := buildPublicModelsListPayload(s.ModelRegistry) + modelsBody, err := json.Marshal(modelsPayload) + if err == nil { + modelsBodyCopy := cloneBytes(modelsBody) + s.cachedModelsListJSON.Store(&modelsBodyCopy) + } else { + s.cachedModelsListJSON.Store(nil) + } + + modelByID := make(map[string][]byte, len(s.ModelRegistry.Entries)) + for _, entry := range s.ModelRegistry.Entries { + if !entry.Enabled { + continue + } + body, marshalErr := json.Marshal(buildPublicModelPayload(entry)) + if marshalErr != nil { + continue + } + modelByID[normalizeLookupKey(entry.ID)] = cloneBytes(body) + } + modelByIDCopy := cloneBytesMap(modelByID) + s.cachedModelByIDJSON.Store(&modelByIDCopy) +} + +func writeJSONBytes(w http.ResponseWriter, status int, body []byte) { + applyCORSHeaders(w) + w.Header().Set("X-Notion2API", "1") + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _, _ = w.Write(body) +} + +func appendHealthzRuntimeFields(body []byte, sessionReady bool, lastRefresh time.Time, lastRefreshError string) []byte { + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 || trimmed[len(trimmed)-1] != '}' { + trimmed = []byte(`{"ok":true}`) + } + trimmed = bytes.TrimSuffix(trimmed, []byte("}")) + tail := map[string]any{ + "session_ready": sessionReady, + "last_session_refresh": formatTimeOrEmpty(lastRefresh), + "last_session_refresh_error": lastRefreshError, + } + tailBody, err := json.Marshal(tail) + if err != nil { + return body + } + tailBody = bytes.TrimPrefix(tailBody, []byte("{")) + out := make([]byte, 0, len(trimmed)+1+len(tailBody)) + out = append(out, trimmed...) + if len(trimmed) > 1 { + out = append(out, ',') + } + out = append(out, tailBody...) 
+ return out +} + +func applyCORSHeaders(w http.ResponseWriter) { + w.Header().Set("Access-Control-Allow-Origin", corsAllowOrigin) + w.Header().Set("Access-Control-Allow-Headers", corsAllowHeaders) + w.Header().Set("Access-Control-Allow-Methods", corsAllowMethods) +} + func writeJSON(w http.ResponseWriter, status int, payload any) { body, err := json.Marshal(payload) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + applyCORSHeaders(w) w.Header().Set("X-Notion2API", "1") w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Admin-Token") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.WriteHeader(status) _, _ = w.Write(body) } @@ -581,13 +929,28 @@ func writeOpenAIError(w http.ResponseWriter, status int, message string, errorTy }) } +func writeInvalidBodyError(w http.ResponseWriter, err error) { + if errors.Is(err, errRequestTooLarge) { + writeOpenAIError(w, http.StatusRequestEntityTooLarge, "request body exceeds configured limit", "invalid_request_error", "request_too_large") + return + } + writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) +} + func nilString() string { return "" } -func decodeBody(r *http.Request) (map[string]any, error) { - defer r.Body.Close() - decoder := json.NewDecoder(r.Body) +func decodeBodyWithLimit(w http.ResponseWriter, r *http.Request, maxBytes int64) (map[string]any, error) { + raw, err := decodeBodyRawWithLimit(w, r, maxBytes) + if err != nil { + return nil, err + } + return decodeBodyMapFromRaw(raw) +} + +func decodeBodyMapFromRaw(raw []byte) (map[string]any, error) { + decoder := json.NewDecoder(bytes.NewReader(raw)) decoder.UseNumber() var payload map[string]any if err := decoder.Decode(&payload); err != nil { @@ -599,6 +962,65 @@ func decodeBody(r *http.Request) 
(map[string]any, error) { return payload, nil } +func decodeBodyRawWithLimit(w http.ResponseWriter, r *http.Request, maxBytes int64) ([]byte, error) { + if maxBytes > 0 && w != nil { + r.Body = http.MaxBytesReader(w, r.Body, maxBytes) + } + defer r.Body.Close() + body, err := io.ReadAll(r.Body) + if err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + return nil, errRequestTooLarge + } + return nil, fmt.Errorf("invalid json: %w", err) + } + trimmed := bytes.TrimSpace(body) + if len(trimmed) == 0 { + return []byte("{}"), nil + } + var raw json.RawMessage + if err := json.Unmarshal(trimmed, &raw); err != nil { + var maxErr *http.MaxBytesError + if errors.As(err, &maxErr) { + return nil, errRequestTooLarge + } + return nil, fmt.Errorf("invalid json: %w", err) + } + normalized := bytes.TrimSpace(raw) + if len(normalized) == 0 { + return []byte("{}"), nil + } + return normalized, nil +} + +func (a *App) decodeBody(w http.ResponseWriter, r *http.Request) (map[string]any, error) { + raw, err := a.decodeBodyRaw(w, r) + if err != nil { + return nil, err + } + return decodeBodyMapFromRaw(raw) +} + +func (a *App) decodeBodyRaw(w http.ResponseWriter, r *http.Request) ([]byte, error) { + maxBytes := int64(0) + if a != nil && a.State != nil { + cfg, _, _ := a.State.Snapshot() + maxBytes = cfg.Limits.MaxRequestBodyBytes + } + return decodeBodyRawWithLimit(w, r, maxBytes) +} + +func decodeTypedBodyFromRaw[T any](raw []byte) (T, error) { + var typed T + decoder := json.NewDecoder(bytes.NewReader(raw)) + decoder.UseNumber() + if err := decoder.Decode(&typed); err != nil { + return typed, fmt.Errorf("invalid json: %w", err) + } + return typed, nil +} + func (a *App) authOK(w http.ResponseWriter, r *http.Request) bool { cfg, _, _ := a.State.Snapshot() expected := strings.TrimSpace(cfg.APIKey) @@ -614,12 +1036,18 @@ func (a *App) authOK(w http.ResponseWriter, r *http.Request) bool { } func (a *App) serveHealthz(w http.ResponseWriter) { - cfg, session, registry 
:= a.State.Snapshot() a.State.mu.RLock() sessionReady := a.State.Client != nil lastRefresh := a.State.LastSessionRefresh lastRefreshError := a.State.LastSessionRefreshError + cached := a.State.cachedHealthzStaticJSON.Load() a.State.mu.RUnlock() + if cached != nil { + body := appendHealthzRuntimeFields(*cached, sessionReady, lastRefresh, lastRefreshError) + writeJSONBytes(w, http.StatusOK, body) + return + } + cfg, session, registry := a.State.Snapshot() writeJSON(w, http.StatusOK, map[string]any{ "ok": true, "default_model": cfg.DefaultPublicModel(), @@ -635,28 +1063,13 @@ func (a *App) serveHealthz(w http.ResponseWriter) { } func (a *App) serveModels(w http.ResponseWriter) { - _, _, registry := a.State.Snapshot() - items := make([]map[string]any, 0, len(registry.Entries)) - for _, entry := range registry.Entries { - if !entry.Enabled { - continue - } - items = append(items, map[string]any{ - "id": entry.ID, - "object": "model", - "created": 0, - "owned_by": "notion2api", - "name": entry.Name, - "family": entry.Family, - "group": entry.Group, - "beta": entry.Beta, - "notion_model": entry.NotionModel, - }) + cached := a.State.cachedModelsListJSON.Load() + if cached != nil { + writeJSONBytes(w, http.StatusOK, *cached) + return } - writeJSON(w, http.StatusOK, map[string]any{ - "object": "list", - "data": items, - }) + _, _, registry := a.State.Snapshot() + writeJSON(w, http.StatusOK, buildPublicModelsListPayload(registry)) } func (a *App) serveModelByID(w http.ResponseWriter, path string) { @@ -667,17 +1080,13 @@ func (a *App) serveModelByID(w http.ResponseWriter, path string) { writeOpenAIError(w, http.StatusNotFound, "model not found", "invalid_request_error", "model_not_found") return } - writeJSON(w, http.StatusOK, map[string]any{ - "id": entry.ID, - "object": "model", - "created": 0, - "owned_by": "notion2api", - "name": entry.Name, - "family": entry.Family, - "group": entry.Group, - "beta": entry.Beta, - "notion_model": entry.NotionModel, - }) + if cached := 
a.State.cachedModelByIDJSON.Load(); cached != nil { + if body, ok := (*cached)[normalizeLookupKey(entry.ID)]; ok && len(body) > 0 { + writeJSONBytes(w, http.StatusOK, body) + return + } + } + writeJSON(w, http.StatusOK, buildPublicModelPayload(entry)) } func (a *App) serveResponseByID(w http.ResponseWriter, path string) { @@ -913,8 +1322,13 @@ func attachConversationResponseMetadata(payload map[string]any, conversationID s } func (a *App) resolveContinuationConversation(r *http.Request, payload map[string]any, previousResponseID string, hiddenPrompt string, segments []conversationPromptSegment) (continuationTarget, bool) { - rawCount := sessionRawMessageCount(segments) explicitConversationID := requestedConversationID(r, payload) + explicitThreadID := requestedThreadID(r, payload) + return a.resolveContinuationConversationWithExplicit(previousResponseID, hiddenPrompt, segments, explicitConversationID, explicitThreadID) +} + +func (a *App) resolveContinuationConversationWithExplicit(previousResponseID string, hiddenPrompt string, segments []conversationPromptSegment, explicitConversationID string, explicitThreadID string) (continuationTarget, bool) { + rawCount := sessionRawMessageCount(segments) validateState := func(state *conversationContinuationState) bool { if state == nil { return true @@ -986,9 +1400,9 @@ func (a *App) resolveContinuationConversation(r *http.Request, payload map[strin } } } - if threadID := requestedThreadID(r, payload); threadID != "" { - if entry, ok := a.State.conversations().FindByThreadID(threadID); ok { - state, err := a.State.loadConversationContinuationStateByThreadID(threadID) + if explicitThreadID != "" { + if entry, ok := a.State.conversations().FindByThreadID(explicitThreadID); ok { + state, err := a.State.loadConversationContinuationStateByThreadID(explicitThreadID) if err == nil && !validateState(state) { return continuationTarget{}, false } @@ -998,9 +1412,9 @@ func (a *App) resolveContinuationConversation(r *http.Request, 
payload map[strin return continuationTarget{Conversation: entry}, true } target := continuationTarget{Conversation: ConversationEntry{ - ThreadID: threadID, + ThreadID: explicitThreadID, }} - if state, err := a.State.loadConversationContinuationStateByThreadID(threadID); err == nil { + if state, err := a.State.loadConversationContinuationStateByThreadID(explicitThreadID); err == nil { if !validateState(state) { return continuationTarget{}, false } @@ -1111,6 +1525,70 @@ func includeUsageInStream(payload map[string]any) bool { return includeUsage } +func decodeChatCompletionsRequestBodyFromRaw(raw []byte) (chatCompletionsRequestBody, map[string]any, error) { + typed, err := decodeTypedBodyFromRaw[chatCompletionsRequestBody](raw) + if err == nil { + return normalizeTypedChatCompletionsRequestBody(typed), nil, nil + } + payload, mapErr := decodeBodyMapFromRaw(raw) + if mapErr != nil { + return chatCompletionsRequestBody{}, nil, mapErr + } + return extractChatCompletionsRequestBody(payload), payload, nil +} + +func decodeResponsesRequestBodyFromRaw(raw []byte) (responsesRequestBody, map[string]any, error) { + typed, err := decodeTypedBodyFromRaw[responsesRequestBody](raw) + if err == nil { + return normalizeTypedResponsesRequestBody(typed), nil, nil + } + payload, mapErr := decodeBodyMapFromRaw(raw) + if mapErr != nil { + return responsesRequestBody{}, nil, mapErr + } + return extractResponsesRequestBody(payload), payload, nil +} + +func maybeSillyTavernByTypedMessages(rawMessages any) bool { + items := sliceValue(rawMessages) + if len(items) == 0 { + return false + } + systemPrompts := make([]string, 0, len(items)) + for _, raw := range items { + msg := mapValue(raw) + if msg == nil { + continue + } + if strings.TrimSpace(strings.ToLower(stringValue(msg["role"]))) != "system" { + continue + } + text := collapseWhitespace(flattenContent(msg["content"])) + if text != "" { + systemPrompts = append(systemPrompts, text) + } + } + if len(systemPrompts) == 0 { + return false 
+ } + if looksLikeSillyTavernImpersonate(systemPrompts) || looksLikeSillyTavernQuiet(systemPrompts, nil) { + return true + } + for _, prompt := range systemPrompts { + lower := strings.ToLower(collapseWhitespace(prompt)) + if strings.Contains(lower, "fictional chat between") || + strings.Contains(lower, "[start a new chat]") || + strings.Contains(lower, "[continue your last message without repeating its original content.]") { + return true + } + } + return false +} + +func rawMayNeedSillyTavernPayloadFallback(raw []byte) bool { + return bytes.Contains(raw, []byte(`"continue_prefill"`)) || bytes.Contains(raw, []byte(`"show_thoughts"`)) +} + func chatCompletionInitialFlushDelayForRequest(request PromptRunRequest) time.Duration { if request.ClientProfile == sillyTavernClientProfile || request.StreamReasoningWarmup { return 0 @@ -1152,16 +1630,33 @@ func (a *App) runPromptStreamWithSink(r *http.Request, request PromptRunRequest, } func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + raw, err := a.decodeBodyRaw(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } - if isLikelySillyTavernPayload(payload) { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + if payload == nil && (typed.likelySillyTavernByEnvelope() || maybeSillyTavernByTypedMessages(typed.Messages) || rawMayNeedSillyTavernPayloadFallback(raw)) { + payload, err = decodeBodyMapFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + } + if payload != nil && (typed.likelySillyTavernByEnvelope() || isLikelySillyTavernPayload(payload)) { a.handleSillyTavernChatCompletionsPayload(w, r, payload) return } - normalized, err := normalizeChatInput(payload) + messages := sliceValue(typed.Messages) + if len(messages) == 0 { + writeOpenAIError(w, 
http.StatusBadRequest, "messages must be an array", "invalid_request_error", nilString()) + return + } + normalized, err := normalizeChatInputFromParts(messages, typed.Attachments) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) return @@ -1171,7 +1666,12 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { return } cfg, _, registry := a.State.Snapshot() - entry, err := registry.Resolve(requestedModel(payload, cfg.DefaultPublicModel()), cfg.DefaultPublicModel()) + requestedModelID := requestedModelFromTyped(typed.Model, cfg.DefaultPublicModel()) + useWebSearch := requestedWebSearchFromTyped(typed.UseWebSearch, typed.Metadata, typed.Tools, cfg.Features.UseWebSearch) + preferredConversationID := requestedConversationIDFromTyped(r, typed.ConversationID, typed.Conversation, typed.Metadata) + explicitThreadID := requestedThreadIDFromTyped(r, typed.ThreadID, typed.Thread, typed.NotionThreadID, typed.Metadata) + requestedAccount := requestedAccountEmailFromTyped(r, typed.AccountEmail, typed.NotionAccountEmail, typed.Metadata) + entry, err := registry.Resolve(requestedModelID, cfg.DefaultPublicModel()) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", "model_not_found") return @@ -1187,17 +1687,16 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { HiddenPrompt: hiddenPrompt, PublicModel: entry.ID, NotionModel: entry.NotionModel, - UseWebSearch: requestedWebSearch(payload, cfg.Features.UseWebSearch), + UseWebSearch: useWebSearch, Attachments: normalized.Attachments, SessionFingerprint: originalFingerprint, RawMessageCount: originalRawMessageCount, } freshThreadMode := forceFreshThreadPerRequest(cfg) - preferredConversationID := requestedConversationID(r, payload) conversation := ConversationEntry{} - if matched, ok := a.resolveContinuationConversation(r, payload, "", hiddenPrompt, normalized.Segments); ok { + 
if matched, ok := a.resolveContinuationConversationWithExplicit("", hiddenPrompt, normalized.Segments, preferredConversationID, explicitThreadID); ok { conversation = matched.Conversation - request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccountEmail(r, payload)) + request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccount) if freshThreadMode { request.ForceLocalConversationContinue = strings.TrimSpace(conversation.ID) != "" request.Prompt = buildFreshThreadReplayPromptFromConversation(conversation, latestPrompt, normalized.Attachments, promptText) @@ -1210,14 +1709,18 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { request.Prompt = latestPrompt } } else { - request.PinnedAccountEmail = requestedAccountEmail(r, payload) + request.PinnedAccountEmail = requestedAccount } request.ConversationID = firstNonEmpty(strings.TrimSpace(conversation.ID), preferredConversationID) conversationID := a.startConversationTurn(conversation.ID, preferredConversationID, "api", "chat_completions", resolveRequestPromptForContinuation(normalized), request) setConversationIDHeader(w, conversationID) - stream, _ := payload["stream"].(bool) + stream := typed.Stream if stream { - a.writeChatCompletionLiveStream(w, r, request, entry.ID, includeUsageInStream(payload), conversationID) + includeUsage := false + if typed.StreamIncludeUsage != nil { + includeUsage = *typed.StreamIncludeUsage + } + a.writeChatCompletionLiveStream(w, r, request, entry.ID, includeUsage, conversationID) return } result, err := a.runPrompt(r, request) @@ -1237,9 +1740,9 @@ func (a *App) handleChatCompletions(w http.ResponseWriter, r *http.Request) { } func (a *App) handleSillyTavernChatCompletions(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + payload, err := a.decodeBody(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), 
"invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } a.handleSillyTavernChatCompletionsPayload(w, r, payload) @@ -1349,14 +1852,19 @@ func (a *App) handleSillyTavernChatCompletionsPayload(w http.ResponseWriter, r * } func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { - payload, err := decodeBody(r) + raw, err := a.decodeBodyRaw(w, r) if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) + writeInvalidBodyError(w, err) return } - stream, _ := payload["stream"].(bool) + typed, _, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + writeInvalidBodyError(w, err) + return + } + stream := typed.Stream var previousResponse map[string]any - previousResponseID := strings.TrimSpace(stringValue(payload["previous_response_id"])) + previousResponseID := strings.TrimSpace(typed.PreviousResponseID) if previousResponseID != "" { var ok bool previousResponse, ok = a.State.getResponse(previousResponseID) @@ -1365,7 +1873,7 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { return } } - normalized, err := normalizeResponsesInput(payload, previousResponse) + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, previousResponse) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", nilString()) return @@ -1375,7 +1883,12 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { return } cfg, _, registry := a.State.Snapshot() - entry, err := registry.Resolve(requestedModel(payload, cfg.DefaultPublicModel()), cfg.DefaultPublicModel()) + requestedModelID := requestedModelFromTyped(typed.Model, cfg.DefaultPublicModel()) + useWebSearch := requestedWebSearchFromTyped(typed.UseWebSearch, typed.Metadata, typed.Tools, cfg.Features.UseWebSearch) + preferredConversationID := requestedConversationIDFromTyped(r, typed.ConversationID, typed.Conversation, typed.Metadata) 
+ explicitThreadID := requestedThreadIDFromTyped(r, typed.ThreadID, typed.Thread, typed.NotionThreadID, typed.Metadata) + requestedAccount := requestedAccountEmailFromTyped(r, typed.AccountEmail, typed.NotionAccountEmail, typed.Metadata) + entry, err := registry.Resolve(requestedModelID, cfg.DefaultPublicModel()) if err != nil { writeOpenAIError(w, http.StatusBadRequest, err.Error(), "invalid_request_error", "model_not_found") return @@ -1391,17 +1904,16 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { HiddenPrompt: hiddenPrompt, PublicModel: entry.ID, NotionModel: entry.NotionModel, - UseWebSearch: requestedWebSearch(payload, cfg.Features.UseWebSearch), + UseWebSearch: useWebSearch, Attachments: normalized.Attachments, SessionFingerprint: originalFingerprint, RawMessageCount: originalRawMessageCount, } freshThreadMode := forceFreshThreadPerRequest(cfg) - preferredConversationID := requestedConversationID(r, payload) conversation := ConversationEntry{} - if matched, ok := a.resolveContinuationConversation(r, payload, previousResponseID, hiddenPrompt, normalized.Segments); ok { + if matched, ok := a.resolveContinuationConversationWithExplicit(previousResponseID, hiddenPrompt, normalized.Segments, preferredConversationID, explicitThreadID); ok { conversation = matched.Conversation - request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccountEmail(r, payload)) + request.PinnedAccountEmail = firstNonEmpty(strings.TrimSpace(conversation.AccountEmail), requestedAccount) if freshThreadMode { request.ForceLocalConversationContinue = strings.TrimSpace(conversation.ID) != "" request.Prompt = buildFreshThreadReplayPromptFromConversation(conversation, latestPrompt, normalized.Attachments, promptText) @@ -1414,7 +1926,7 @@ func (a *App) handleResponses(w http.ResponseWriter, r *http.Request) { request.Prompt = latestPrompt } } else { - request.PinnedAccountEmail = requestedAccountEmail(r, payload) + 
request.PinnedAccountEmail = requestedAccount } if freshThreadMode && strings.TrimSpace(conversation.ID) == "" { request.Prompt = buildFreshThreadReplayPromptFromStoredResponse(normalized.PreviousResponsePrompt, latestPrompt, normalized.Attachments, request.Prompt) @@ -1468,11 +1980,11 @@ func (a *App) writeUpstreamError(w http.ResponseWriter, err error) { } func prepareOpenAISSEHeaders(w http.ResponseWriter) { + applyCORSHeaders(w) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") - w.Header().Set("Access-Control-Allow-Origin", "*") w.WriteHeader(http.StatusOK) } @@ -2106,7 +2618,13 @@ func (a *App) writeResponsesStream(w http.ResponseWriter, r *http.Request, resul } func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { + startedAt := time.Now() + statusCode := http.StatusOK + defer func() { + observeRequestDuration(r.URL.Path, r.Method, statusCode, time.Since(startedAt)) + }() safeWriter := &panicSafeResponseWriter{ResponseWriter: w} + applyCORSHeaders(safeWriter) defer func() { if recovered := recover(); recovered != nil { stack := strings.TrimSpace(string(debug.Stack())) @@ -2139,10 +2657,8 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { }() if r.Method == http.MethodOptions { - safeWriter.Header().Set("Access-Control-Allow-Origin", "*") - safeWriter.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Admin-Token") - safeWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") safeWriter.WriteHeader(http.StatusNoContent) + statusCode = safeWriter.status return } @@ -2150,16 +2666,20 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case r.Method == http.MethodGet && path == "/": a.serveIndex(safeWriter) + statusCode = safeWriter.status return case strings.HasPrefix(path, "/admin"): 
a.handleAdmin(safeWriter, r) + statusCode = safeWriter.status return case r.Method == http.MethodGet && path == "/healthz": a.serveHealthz(safeWriter) + statusCode = safeWriter.status return } if !a.authOK(safeWriter, r) { + statusCode = safeWriter.status return } @@ -2168,6 +2688,10 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { a.serveModels(safeWriter) case r.Method == http.MethodGet && strings.HasPrefix(path, "/v1/models/"): a.serveModelByID(safeWriter, path) + case r.Method == http.MethodGet && path == "/debug/vars": + expvar.Handler().ServeHTTP(safeWriter, r) + case r.Method == http.MethodGet && path == "/metrics": + writePrometheusMetrics(safeWriter) case r.Method == http.MethodGet && strings.HasPrefix(path, "/v1/responses/"): a.serveResponseByID(safeWriter, path) case r.Method == http.MethodPost && path == "/v1/st/chat/completions": @@ -2179,6 +2703,7 @@ func (a *App) ServeHTTP(w http.ResponseWriter, r *http.Request) { default: writeOpenAIError(safeWriter, http.StatusNotFound, "route not found", "invalid_request_error", "not_found") } + statusCode = safeWriter.status } func Main() { @@ -2190,6 +2715,14 @@ func Main() { app := &App{State: state} state.StartSessionRefreshLoop(context.Background()) app.StartEphemeralConversationCleanupLoop(context.Background()) + if cfg.Debug.PprofEnabled { + go func(addr string) { + log.Printf("[pprof] listening on http://%s/debug/pprof/ (local debug endpoint; avoid public exposure)", addr) + if err := http.ListenAndServe(addr, nil); err != nil { + log.Printf("[pprof] server stopped: %v", err) + } + }(cfg.Debug.PprofAddr) + } addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) server := &http.Server{ Addr: addr, diff --git a/internal/app/main_fresh_thread_test.go b/internal/app/main_fresh_thread_test.go index fd2a3ad..7906df0 100644 --- a/internal/app/main_fresh_thread_test.go +++ b/internal/app/main_fresh_thread_test.go @@ -2,15 +2,21 @@ package app import ( "bytes" + "context" "encoding/json" "errors" + 
"expvar" "fmt" "net/http" "net/http/httptest" "os" "path/filepath" + "runtime" + "sort" + "strconv" "strings" "testing" + "time" ) func newFreshThreadTestApp(t *testing.T) *App { @@ -244,6 +250,1252 @@ func TestHandleSillyTavernFreshThreadReplaysLocalConversation(t *testing.T) { assertConversationContinued(t, app, seeded.ID, "thread-new-st", "The story continues.") } +func TestNormalizeConfigSetsPprofDefaults(t *testing.T) { + cfg := normalizeConfig(AppConfig{}) + if cfg.Debug.PprofEnabled { + t.Fatalf("expected pprof disabled by default") + } + if cfg.Debug.PprofAddr != "127.0.0.1:6060" { + t.Fatalf("unexpected default pprof addr: %q", cfg.Debug.PprofAddr) + } +} + +func TestDefaultConfigSurfHelperTransportDisabled(t *testing.T) { + cfg := defaultConfig() + if cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected default use_surf_helper_transport=false") + } +} + +func TestNormalizeConfigKeepsSurfHelperTransportEnabled(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Features: FeatureConfig{UseSurfHelperTransport: true}, + }) + if !cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected normalizeConfig to preserve use_surf_helper_transport=true") + } +} + +func TestDefaultConfigSetsDispatchProbeCacheTTLDefault(t *testing.T) { + cfg := defaultConfig() + if cfg.Dispatch.ProbeCacheTTLSeconds != 45 { + t.Fatalf("unexpected default dispatch probe cache ttl: %d", cfg.Dispatch.ProbeCacheTTLSeconds) + } +} + +func TestNormalizeConfigClampsNegativeDispatchProbeCacheTTL(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Dispatch: DispatchConfig{ProbeCacheTTLSeconds: -3}, + }) + if cfg.Dispatch.ProbeCacheTTLSeconds != 0 { + t.Fatalf("expected negative dispatch probe cache ttl to clamp to 0, got %d", cfg.Dispatch.ProbeCacheTTLSeconds) + } +} + +func TestDefaultConfigBrowserHelperPoolSizeDefaultZero(t *testing.T) { + cfg := defaultConfig() + if got := cfg.Browser.HelperPoolSize; got != 0 { + t.Fatalf("unexpected default browser helper pool size: got %d want 
%d", got, 0) + } +} + +func TestNormalizeConfigClampsBrowserHelperPoolSizeBounds(t *testing.T) { + negative := normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: -2}, + }) + if got := negative.Browser.HelperPoolSize; got != 0 { + t.Fatalf("expected negative helper pool size clamp to 0, got %d", got) + } + tooLarge := normalizeConfig(AppConfig{ + Browser: BrowserConfig{HelperPoolSize: 99}, + }) + if got := tooLarge.Browser.HelperPoolSize; got != 8 { + t.Fatalf("expected oversized helper pool size clamp to 8, got %d", got) + } +} + +func TestEmbeddedBrowserHelperAssetsRemoved(t *testing.T) { + _, err1 := os.Stat("internal/app/assets/browser-helper.cjs") + _, err2 := os.Stat("internal/app/assets/browser-login-helper.cjs") + if !errors.Is(err1, os.ErrNotExist) || !errors.Is(err2, os.ErrNotExist) { + t.Fatalf("node helper assets still exist") + } +} + +func TestSurfHelperTransportFeatureEnabledUsesSurfPath(t *testing.T) { + cfg := defaultConfig() + cfg.Features.UseSurfHelperTransport = true + if !cfg.Features.UseSurfHelperTransport { + t.Fatalf("expected surf flag enabled") + } +} + +func TestNormalizeConfigPrecomputesRetryPrefixes(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + Prompt: PromptConfig{ + CodingRetryPrefixes: []string{"custom-coding-prefix"}, + GeneralRetryPrefixes: []string{"custom-general-prefix"}, + DirectAnswerRetryPrefixes: []string{"custom-direct-prefix"}, + }, + }) + if len(cfg.Prompt.precomputedAllRetryPrefixes) == 0 { + t.Fatalf("expected precomputed retry prefixes") + } + joined := strings.Join(cfg.Prompt.precomputedAllRetryPrefixes, "\n") + for _, required := range []string{ + "custom-coding-prefix", + "custom-general-prefix", + "custom-direct-prefix", + } { + if !strings.Contains(joined, required) { + t.Fatalf("precomputed retry prefixes missing %q", required) + } + } +} + +func TestEnsureAccountPathsSetsEmailKey(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + LoginHelper: LoginHelperConfig{SessionsDir: 
"probe_files/notion_accounts"}, + }) + account := ensureAccountPaths(cfg, NotionAccount{Email: " Alice@Example.COM "}) + if account.emailKey != "alice@example.com" { + t.Fatalf("unexpected cached email key: %q", account.emailKey) + } +} + +func BenchmarkPromptGuardLooksLikeCodingRequest(b *testing.B) { + text := "Please help debug this golang function and refactor the docker deployment script." + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = promptGuardLooksLikeCodingRequest(text) + } +} + +func BenchmarkPromptGuardStripRetryPrefixes(b *testing.B) { + cfg := normalizeConfig(AppConfig{ + Prompt: PromptConfig{ + CodingRetryPrefixes: []string{"custom-coding-prefix"}, + GeneralRetryPrefixes: []string{"custom-general-prefix"}, + DirectAnswerRetryPrefixes: []string{"custom-direct-prefix"}, + }, + }) + base := "this is a coding request body" + input := cfg.Prompt.CodingRetryPrefixes[0] + cfg.Prompt.GeneralRetryPrefixes[0] + base + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _ = promptGuardStripRetryPrefixes(cfg, input) + } +} + +func BenchmarkServeModelsCaching(b *testing.B) { + cfg := defaultConfig() + cfg.APIKey = "bench-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + b.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer bench-api-key") + + b.Run("cached", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + b.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + } + }) + + b.Run("uncached_fallback", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + state.cachedModelsListJSON.Store(nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + b.Fatalf("unexpected 
status: got %d want %d", rec.Code, http.StatusOK) + } + } + }) +} + +func BenchmarkDecodeChatCompletionsTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":true, + "stream_options":{"include_usage":"1"}, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":"请总结这段文本并给出要点。"} + ], + "metadata":{"use_web_search":false} + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + if len(sliceValue(typed.Messages)) == 0 { + b.Fatalf("expected typed messages") + } + } +} + +func BenchmarkDecodeChatCompletionsMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":true, + "stream_options":{"include_usage":"1"}, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":"请总结这段文本并给出要点。"} + ], + "metadata":{"use_web_search":false} + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractChatCompletionsRequestBody(payload) + if len(sliceValue(typed.Messages)) == 0 { + b.Fatalf("expected map-extracted messages") + } + } +} + +func BenchmarkNormalizeChatInputFromTypedMessages(b *testing.B) { + raw := []byte(`{ + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + typed, _, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + messages := sliceValue(typed.Messages) + if len(messages) == 0 { + b.Fatalf("expected typed 
messages") + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeChatInputFromParts(messages, typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkNormalizeChatInputFromMapMessages(b *testing.B) { + raw := []byte(`{ + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + messages := sliceValue(payload["messages"]) + if len(messages) == 0 { + b.Fatalf("expected map messages") + } + attachments := payload["attachments"] + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeChatInputFromParts(messages, attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkDecodeResponsesTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "previous_response_id":"resp_123", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "metadata":{"use_web_search":"1"}, + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + if len(sliceValue(typed.Input)) == 0 { + b.Fatalf("expected typed input items") + } + } +} + +func BenchmarkDecodeResponsesMapOnly(b *testing.B) { + raw := 
[]byte(`{ + "model":"gpt-5.4", + "stream":false, + "previous_response_id":"resp_123", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "metadata":{"use_web_search":"1"}, + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractResponsesRequestBody(payload) + if len(sliceValue(typed.Input)) == 0 { + b.Fatalf("expected map-extracted responses input") + } + } +} + +func BenchmarkNormalizeResponsesInputFromTyped(b *testing.B) { + raw := []byte(`{ + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + typed, _, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkNormalizeResponsesInputFromMap(b *testing.B) { + raw := []byte(`{ + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + b.ReportAllocs() + for i := 0; i < b.N; i++ { + normalized, err := normalizeResponsesInputFromParts(payload["input"], payload["attachments"], nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized 
prompt") + } + } +} + +func BenchmarkChatDecodeAndNormalizeTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + normalized, err := normalizeChatInputFromParts(sliceValue(typed.Messages), typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkChatDecodeAndNormalizeMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "stream":false, + "messages":[ + {"role":"system","content":"You are helpful."}, + {"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]} + ], + "attachments":[{"type":"image_url","url":"https://example.com/a.png"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractChatCompletionsRequestBody(payload) + normalized, err := normalizeChatInputFromParts(sliceValue(typed.Messages), typed.Attachments) + if err != nil { + b.Fatalf("normalizeChatInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkResponsesDecodeAndNormalizeTypedFirst(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + 
"attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if err != nil { + b.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + b.Fatalf("unexpected map fallback on typed benchmark path") + } + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func BenchmarkResponsesDecodeAndNormalizeMapOnly(b *testing.B) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":[ + {"type":"input_text","text":"hello"}, + {"type":"input_text","text":"world"} + ], + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}] + }`) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + payload, err := decodeBodyMapFromRaw(raw) + if err != nil { + b.Fatalf("decodeBodyMapFromRaw failed: %v", err) + } + typed := extractResponsesRequestBody(payload) + normalized, err := normalizeResponsesInputFromParts(typed.Input, typed.Attachments, nil) + if err != nil { + b.Fatalf("normalizeResponsesInputFromParts failed: %v", err) + } + if normalized.Prompt == "" { + b.Fatalf("expected normalized prompt") + } + } +} + +func TestServeModelsUsesStaticJSONCache(t *testing.T) { + app := newFreshThreadTestApp(t) + raw := []byte(`{"object":"list","data":[{"id":"cached-model","object":"model"}]}`) + ready := append([]byte(nil), raw...) 
+ app.State.cachedModelsListJSON.Store(&ready) + req := httptest.NewRequest(http.MethodGet, "/v1/models", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := strings.TrimSpace(rec.Body.String()); got != string(raw) { + t.Fatalf("expected cached body, got %s", got) + } +} + +func TestServeModelByIDUsesStaticJSONCache(t *testing.T) { + app := newFreshThreadTestApp(t) + _, _, registry := app.State.Snapshot() + entry, err := registry.Resolve("gpt-5.4", "auto") + if err != nil { + t.Fatalf("resolve model failed: %v", err) + } + body := []byte(`{"id":"gpt-5.4","object":"model","cached":true}`) + cache := map[string][]byte{ + normalizeLookupKey(entry.ID): append([]byte(nil), body...), + } + app.State.cachedModelByIDJSON.Store(&cache) + req := httptest.NewRequest(http.MethodGet, "/v1/models/"+entry.ID, nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := strings.TrimSpace(rec.Body.String()); got != string(body) { + t.Fatalf("expected cached body, got %s", got) + } +} + +func TestServeHealthzIncludesRefreshRuntimeFieldsWhenStaticCacheExists(t *testing.T) { + app := newFreshThreadTestApp(t) + static := []byte(`{"ok":true,"default_model":"gpt-5.4","model_count":3,"user_email":"user@example.com","space_id":"space-id","active_account":"acc@example.com","session_refresh_enabled":true}`) + staticCopy := append([]byte(nil), static...) 
+ app.State.cachedHealthzStaticJSON.Store(&staticCopy) + app.State.mu.Lock() + app.State.LastSessionRefresh = time.Date(2026, time.January, 2, 3, 4, 5, 0, time.UTC) + app.State.LastSessionRefreshError = "refresh failed" + app.State.mu.Unlock() + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("unmarshal healthz failed: %v", err) + } + if got, _ := payload["default_model"].(string); got != "gpt-5.4" { + t.Fatalf("unexpected default_model: %q", got) + } + if got, ok := payload["session_ready"].(bool); !ok || got { + t.Fatalf("unexpected session_ready: %#v", payload["session_ready"]) + } + if got, _ := payload["last_session_refresh"].(string); got != "2026-01-02T03:04:05Z" { + t.Fatalf("unexpected last_session_refresh: %q", got) + } + if got, _ := payload["last_session_refresh_error"].(string); got != "refresh failed" { + t.Fatalf("unexpected last_session_refresh_error: %q", got) + } +} + +func TestServeHTTPDebugVarsExposesWreqClientMetric(t *testing.T) { + app := newFreshThreadTestApp(t) + before := int64(0) + if value := transportClientNewTotalMetric.Get("standard"); value != nil { + before = value.(*expvar.Int).Value() + } + transportClientNewTotalMetric.Add("standard", 1) + req := httptest.NewRequest(http.MethodGet, "/debug/vars", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + body := rec.Body.String() + if !strings.Contains(body, `"notion2api_transport_client_new_total"`) { + t.Fatalf("expected metrics payload to include wreq client metric, got %s", body) + } + if 
!strings.Contains(body, `"notion2api_http_transport_cache_total"`) { + t.Fatalf("expected metrics payload to include transport cache metric, got %s", body) + } + after := int64(0) + if value := transportClientNewTotalMetric.Get("standard"); value != nil { + after = value.(*expvar.Int).Value() + } + if after < before+1 { + t.Fatalf("expected metric value to be incremented, before=%d after=%d", before, after) + } +} + +func TestServeHTTPMetricsExposesCorePrometheusSeries(t *testing.T) { + resetMetricsForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + setDispatchSlotInflight("alice@example.com", 2) + observeTransportCallDuration(25 * time.Millisecond) + observeSQLiteOpDuration("save_response", 2*time.Millisecond) + addBrowserHelperSpawn() + addBrowserHelperPoolWorkerSpawn() + + warmReq := httptest.NewRequest(http.MethodGet, "/healthz", nil) + warmRec := httptest.NewRecorder() + app.ServeHTTP(warmRec, warmReq) + if warmRec.Code != http.StatusOK { + t.Fatalf("unexpected warm-up status: got %d want %d", warmRec.Code, http.StatusOK) + } + + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusOK, rec.Body.String()) + } + body := rec.Body.String() + for _, want := range []string{ + "notion2api_request_duration_seconds_bucket", + "notion2api_dispatch_slot_inflight", + "notion2api_transport_call_duration_seconds_bucket", + "notion2api_browser_helper_spawn_total", + "notion2api_browser_helper_pool_worker_spawn_total", + "notion2api_sqlite_op_duration_seconds_bucket", + "notion2api_response_store_prune_total", + } { + if 
!strings.Contains(body, want) { + t.Fatalf("expected /metrics output to include %q, got: %s", want, body) + } + } + if !strings.Contains(body, "notion2api_browser_helper_pool_worker_spawn_total 1") { + t.Fatalf("expected pool worker spawn counter value to be 1, got: %s", body) + } +} + +func TestSnapshotReadsFromAtomicBundle(t *testing.T) { + state := &ServerState{} + cfg := defaultConfig() + cfg.APIKey = "snapshot-api-key" + session := SessionInfo{UserID: "user-1", SpaceID: "space-1"} + registry := ModelRegistry{ + Entries: []ModelDefinition{ + {ID: "gpt-5.4", Enabled: true}, + }, + } + + state.mu.Lock() + state.Config = cfg + state.Session = session + state.ModelRegistry = registry + state.updateSnapshotBundleLocked() + state.mu.Unlock() + + gotCfg, gotSession, gotRegistry := state.Snapshot() + if gotCfg.APIKey != cfg.APIKey { + t.Fatalf("snapshot cfg mismatch: got %q want %q", gotCfg.APIKey, cfg.APIKey) + } + if gotSession.UserID != session.UserID || gotSession.SpaceID != session.SpaceID { + t.Fatalf("snapshot session mismatch: got %+v want %+v", gotSession, session) + } + if len(gotRegistry.Entries) != 1 || gotRegistry.Entries[0].ID != "gpt-5.4" { + t.Fatalf("snapshot registry mismatch: %+v", gotRegistry.Entries) + } + if len(state.snap.Load().DispatchOrder) != 0 { + t.Fatalf("expected empty dispatch order for empty accounts") + } +} + +func TestSnapshotDispatchOrderPrecomputed(t *testing.T) { + tempDir := t.TempDir() + aliceProbe := filepath.Join(tempDir, "alice-probe.json") + bobProbe := filepath.Join(tempDir, "bob-probe.json") + if err := os.WriteFile(aliceProbe, []byte(`{"ok":true}`), 0o600); err != nil { + t.Fatalf("write alice probe failed: %v", err) + } + if err := os.WriteFile(bobProbe, []byte(`{"ok":true}`), 0o600); err != nil { + t.Fatalf("write bob probe failed: %v", err) + } + + cfg := defaultConfig() + cfg.APIKey = "snapshot-dispatch-order-api-key" + cfg.ActiveAccount = "bob@example.com" + cfg.Accounts = []NotionAccount{ + {Email: 
"alice@example.com", Priority: 10, MaxConcurrency: 1, ProbeJSON: aliceProbe}, + {Email: "bob@example.com", Priority: 1, MaxConcurrency: 1, ProbeJSON: bobProbe}, + {Email: "carol@example.com", Priority: 50, MaxConcurrency: 1, Disabled: true}, + } + cfg = normalizeConfig(cfg) + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + + snap := state.snap.Load() + if snap == nil { + t.Fatalf("expected non-nil snapshot bundle") + } + if len(snap.DispatchOrder) != 2 { + t.Fatalf("unexpected dispatch order length: got %d want 2", len(snap.DispatchOrder)) + } + if getAccountEmailKey(snap.DispatchOrder[0]) != "bob@example.com" { + t.Fatalf("expected active account first in precomputed dispatch order, got %q", snap.DispatchOrder[0].Email) + } + if getAccountEmailKey(snap.DispatchOrder[1]) != "alice@example.com" { + t.Fatalf("expected second candidate to be alice, got %q", snap.DispatchOrder[1].Email) + } +} + +func TestResolveDispatchCandidatesFromSnapshotUsesPrecomputedOrder(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Accounts: []NotionAccount{ + {Email: "first@example.com", Priority: 10, MaxConcurrency: 1}, + {Email: "second@example.com", Priority: 20, MaxConcurrency: 1}, + }, + ActiveAccount: "first@example.com", + }) + now := time.Now() + bundle := &snapshotBundle{ + Config: cfg, + DispatchOrder: []NotionAccount{ + {Email: "second@example.com", Priority: 20, MaxConcurrency: 1}, + {Email: "first@example.com", Priority: 10, MaxConcurrency: 1}, + }, + } + candidates, err := resolveDispatchCandidatesFromSnapshot(bundle, PromptRunRequest{}, now) + if err != nil { + t.Fatalf("resolveDispatchCandidatesFromSnapshot failed: %v", err) + } + if len(candidates) != 2 { + t.Fatalf("unexpected candidates length: got %d want 2", len(candidates)) + } + if getAccountEmailKey(candidates[0]) != "second@example.com" || getAccountEmailKey(candidates[1]) != 
"first@example.com" { + t.Fatalf("unexpected candidate order from snapshot: %+v", candidates) + } +} + +func TestConversationStoreGetReturnsValueSnapshotAfterMutation(t *testing.T) { + store := newConversationStore() + created := store.Create(ConversationCreateRequest{ + PreferredID: "conv-value-snapshot", + Source: "api", + Transport: "chat_completions", + Model: "gpt-5.4", + Prompt: "hello", + }) + got1, ok := store.Get(created.ID) + if !ok { + t.Fatalf("expected created conversation to exist") + } + if got1.Status != "running" { + t.Fatalf("unexpected initial status: %q", got1.Status) + } + + store.Complete(created.ID, InferenceResult{ + Text: "done", + ThreadID: "thread-1", + AccountEmail: "alice@example.com", + }) + + got2, ok := store.Get(created.ID) + if !ok { + t.Fatalf("expected conversation after completion") + } + if got2.Status != "completed" { + t.Fatalf("unexpected status after complete: %q", got2.Status) + } + if got1.Status == got2.Status { + t.Fatalf("expected old value snapshot to remain unchanged, got1=%q got2=%q", got1.Status, got2.Status) + } +} + +func TestConversationStoreSummaryUsesCachedPreviewAfterMutations(t *testing.T) { + store := newConversationStore() + created := store.Create(ConversationCreateRequest{ + PreferredID: "conv-preview-cache", + Source: "api", + Transport: "chat_completions", + Model: "gpt-5.4", + Prompt: "first question", + }) + list1 := store.List() + if len(list1) == 0 { + t.Fatalf("expected list to have one entry") + } + if !strings.Contains(list1[0].Preview, "first question") { + t.Fatalf("unexpected initial preview: %q", list1[0].Preview) + } + + store.AppendAssistantDelta(created.ID, "assistant draft") + list2 := store.List() + if len(list2) == 0 { + t.Fatalf("expected list to have one entry after delta") + } + if !strings.Contains(list2[0].Preview, "assistant draft") { + t.Fatalf("expected preview to reflect assistant delta, got %q", list2[0].Preview) + } + + store.Complete(created.ID, InferenceResult{ + Text: 
"final assistant reply", + ThreadID: "thread-preview", + AccountEmail: "preview@example.com", + }) + list3 := store.List() + if len(list3) == 0 { + t.Fatalf("expected list to have one entry after complete") + } + if !strings.Contains(list3[0].Preview, "final assistant reply") { + t.Fatalf("expected preview to reflect completed assistant text, got %q", list3[0].Preview) + } +} + +func TestRequestedWebSearchFromTypedMetadataAndTools(t *testing.T) { + if got := requestedWebSearchFromTyped(nil, json.RawMessage(`{"use_web_search": true}`), nil, false); !got { + t.Fatalf("expected use_web_search=true from metadata to enable web search") + } + if got := requestedWebSearchFromTyped(nil, json.RawMessage(`{"notion_use_web_search":"false"}`), nil, true); got { + t.Fatalf("expected notion_use_web_search=false metadata to disable web search") + } + if got := requestedWebSearchFromTyped(nil, nil, json.RawMessage(`[{"type":"web_search_preview"}]`), false); !got { + t.Fatalf("expected web_search tool to enable web search") + } + if got := requestedWebSearchFromTyped(nil, map[string]any{"use_web_search": "1"}, nil, false); !got { + t.Fatalf("expected use_web_search=1 map metadata to enable web search") + } + if got := requestedWebSearchFromTyped(nil, nil, []map[string]any{{"type": "web_search_legacy"}}, false); !got { + t.Fatalf("expected web_search tool map slice to enable web search") + } +} + +func TestExtractTypedRequestBodies(t *testing.T) { + chatPayload := map[string]any{ + "model": "gpt-5.4", + "stream": true, + "stream_options": map[string]any{"include_usage": true}, + "conversation_id": "conv-typed-chat", + "account_email": "typed@example.com", + "use_web_search": "true", + "metadata": map[string]any{"notion_use_web_search": false}, + "attachments": []any{map[string]any{"type": "image_url", "url": "https://example.com/image.png"}}, + "messages": []any{map[string]any{"role": "user", "content": "hello"}}, + "type": "continue", + "user_name": "user", + "char_name": "char", + 
"group_names": []any{"g1"}, + "continue_prefill": "next", + "show_thoughts": true, + "notion_account_email": "typed2@example.com", + } + chatTyped := extractChatCompletionsRequestBody(chatPayload) + if chatTyped.Model != "gpt-5.4" || !chatTyped.Stream { + t.Fatalf("unexpected typed chat body: %+v", chatTyped) + } + if chatTyped.UseWebSearch == nil || !*chatTyped.UseWebSearch { + t.Fatalf("expected typed chat use_web_search=true") + } + if chatTyped.StreamIncludeUsage == nil || !*chatTyped.StreamIncludeUsage { + t.Fatalf("expected typed chat stream include_usage=true") + } + if _, ok := chatTyped.Attachments.([]any); !ok { + t.Fatalf("expected typed chat attachments to keep raw array type") + } + if _, ok := chatTyped.Messages.([]any); !ok { + t.Fatalf("expected typed chat messages to keep raw array type") + } + if !chatTyped.likelySillyTavernByEnvelope() { + t.Fatalf("expected chat body to be identified as likely sillytavern by envelope") + } + + respPayload := map[string]any{ + "model": "gpt-5.4", + "stream": false, + "previous_response_id": "resp_123", + "conversation_id": "conv-typed-responses", + "thread_id": "thread-typed", + "account_email": "resp@example.com", + "use_web_search": true, + "metadata": map[string]any{"use_web_search": true}, + "input": []any{map[string]any{"type": "text", "text": "input payload"}}, + "attachments": []any{map[string]any{"type": "file", "file_url": "https://example.com/file.txt"}}, + } + respTyped := extractResponsesRequestBody(respPayload) + if respTyped.Model != "gpt-5.4" || respTyped.Stream { + t.Fatalf("unexpected typed responses body: %+v", respTyped) + } + if respTyped.PreviousResponseID != "resp_123" || respTyped.ConversationID != "conv-typed-responses" { + t.Fatalf("unexpected typed responses ids: %+v", respTyped) + } + if respTyped.UseWebSearch == nil || !*respTyped.UseWebSearch { + t.Fatalf("expected typed responses use_web_search=true") + } + if _, ok := respTyped.Input.([]any); !ok { + t.Fatalf("expected typed 
responses input to keep raw array type") + } + if _, ok := respTyped.Attachments.([]any); !ok { + t.Fatalf("expected typed responses attachments to keep raw array type") + } +} + +func TestExtractChatTypedStreamIncludeUsageParsing(t *testing.T) { + fromRaw := extractChatCompletionsRequestBody(map[string]any{ + "stream_options": json.RawMessage(`{"include_usage":"1"}`), + }) + if fromRaw.StreamIncludeUsage == nil || !*fromRaw.StreamIncludeUsage { + t.Fatalf("expected stream include_usage to parse true from raw json string flag") + } + + fromMapFalse := extractChatCompletionsRequestBody(map[string]any{ + "stream_options": map[string]any{"include_usage": false}, + }) + if fromMapFalse.StreamIncludeUsage == nil { + t.Fatalf("expected stream include_usage pointer to be populated for explicit false") + } + if *fromMapFalse.StreamIncludeUsage { + t.Fatalf("expected stream include_usage=false from typed stream_options map") + } +} + +func TestRequestedIdentifiersFromTypedRespectHeaders(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + req.Header.Set("X-Conversation-ID", "header-conv") + req.Header.Set("X-Thread-ID", "header-thread") + req.Header.Set("X-Account-Email", "header@example.com") + + if got := requestedConversationIDFromTyped(req, "body-conv", "body-conv2", map[string]any{"conversation_id": "meta-conv"}); got != "header-conv" { + t.Fatalf("conversation id should prefer header, got %q", got) + } + if got := requestedThreadIDFromTyped(req, "body-thread", "body-thread2", "body-thread3", map[string]any{"thread_id": "meta-thread"}); got != "header-thread" { + t.Fatalf("thread id should prefer header, got %q", got) + } + if got := requestedAccountEmailFromTyped(req, "body@example.com", "body2@example.com", map[string]any{"account_email": "meta@example.com"}); got != "header@example.com" { + t.Fatalf("account email should prefer header, got %q", got) + } +} + +func TestRequestedIdentifiersFromTypedFallbackToMetadata(t 
*testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + metadata := json.RawMessage(`{"conversation_id":"meta-conv","thread_id":"meta-thread","account_email":"meta@example.com"}`) + + if got := requestedConversationIDFromTyped(req, "", "", metadata); got != "meta-conv" { + t.Fatalf("conversation id should fallback to metadata, got %q", got) + } + if got := requestedThreadIDFromTyped(req, "", "", "", metadata); got != "meta-thread" { + t.Fatalf("thread id should fallback to metadata, got %q", got) + } + if got := requestedAccountEmailFromTyped(req, "", "", metadata); got != "meta@example.com" { + t.Fatalf("account email should fallback to metadata, got %q", got) + } +} + +func TestResolveContinuationConversationWithExplicitUsesTypedThreadID(t *testing.T) { + app := newFreshThreadTestApp(t) + seeded := seedCompletedConversation(t, app, "conv-typed-explicit", "Seed question", "Seed answer", "thread-explicit") + + segments := []conversationPromptSegment{ + {Role: "user", Text: "follow up"}, + } + + target, ok := app.resolveContinuationConversationWithExplicit("", "", segments, "", "thread-explicit") + if !ok { + t.Fatalf("expected explicit typed thread id to resolve continuation target") + } + if strings.TrimSpace(target.Conversation.ID) != seeded.ID { + t.Fatalf("unexpected resolved conversation id: got %q want %q", target.Conversation.ID, seeded.ID) + } + if strings.TrimSpace(target.Conversation.ThreadID) != "thread-explicit" { + t.Fatalf("unexpected resolved thread id: got %q", target.Conversation.ThreadID) + } +} + +func TestTypedEnvelopeExtractionFallsBackToLegacyWhenTypedFieldsMissing(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + var captured PromptRunRequest + app.runPromptOverride = func(_ 
*http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "typed fallback ok", + ThreadID: "thread-typed-fallback", + MessageID: "msg-typed-fallback", + TraceID: "trace-typed-fallback", + AccountEmail: "header@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + req.Header.Set("X-Account-Email", "header@example.com") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + if captured.PinnedAccountEmail != "header@example.com" { + t.Fatalf("expected pinned account from header fallback path, got %q", captured.PinnedAccountEmail) + } + if captured.PublicModel != "gpt-5.4" { + t.Fatalf("expected resolved model from legacy payload path, got %q", captured.PublicModel) + } +} + +func sqliteWriterFallbackValue(reason string) int64 { + if strings.TrimSpace(reason) == "" { + return 0 + } + value := sqliteWriterFallbackTotalMetric.Get(reason) + if value == nil { + return 0 + } + counter, ok := value.(*expvar.Int) + if !ok || counter == nil { + return 0 + } + return counter.Value() +} + +func boolPtr(value bool) *bool { + return &value +} + +func TestSaveResponseWithAccountPersistsViaAsyncSQLiteWriter(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "responses.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer 
func() { + _ = state.Close() + }() + + responseID := "resp_async_test_1" + payload := map[string]any{ + "id": responseID, + "object": "response", + "output": []any{ + map[string]any{ + "type": "message", + "content": []any{ + map[string]any{ + "type": "output_text", + "text": "hello from async sqlite writer", + }, + }, + }, + }, + } + state.saveResponseWithAccount(responseID, payload, "conv-async", "thread-async", "async@example.com") + + deadline := time.Now().Add(3 * time.Second) + for { + record, ok := state.getStoredResponse(responseID) + if ok && strings.TrimSpace(record.ThreadID) == "thread-async" { + break + } + if time.Now().After(deadline) { + t.Fatalf("response not visible in in-memory store before deadline") + } + time.Sleep(25 * time.Millisecond) + } + + readStore, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore(read) failed: %v", err) + } + defer func() { + _ = readStore.Close() + }() + + waitUntil := time.Now().Add(3 * time.Second) + for { + rows, queryErr := readStore.db.Query(`SELECT payload_json, conversation_id, thread_id, account_email FROM responses WHERE response_id = ?`, responseID) + if queryErr != nil { + t.Fatalf("query persisted response failed: %v", queryErr) + } + found := false + var rawPayload string + var conversationID string + var threadID string + var accountEmail string + for rows.Next() { + found = true + if scanErr := rows.Scan(&rawPayload, &conversationID, &threadID, &accountEmail); scanErr != nil { + _ = rows.Close() + t.Fatalf("scan persisted response failed: %v", scanErr) + } + } + _ = rows.Close() + if found { + if strings.TrimSpace(conversationID) != "conv-async" { + t.Fatalf("conversation_id mismatch: got %q want %q", conversationID, "conv-async") + } + if strings.TrimSpace(threadID) != "thread-async" { + t.Fatalf("thread_id mismatch: got %q want %q", threadID, "thread-async") + } + if strings.TrimSpace(accountEmail) != "async@example.com" { + t.Fatalf("account_email mismatch: got %q want %q", 
accountEmail, "async@example.com") + } + if !strings.Contains(rawPayload, "hello from async sqlite writer") { + t.Fatalf("unexpected payload_json: %s", rawPayload) + } + break + } + if time.Now().After(waitUntil) { + t.Fatalf("persisted response not found before deadline") + } + time.Sleep(25 * time.Millisecond) + } +} + +func TestSQLiteWriterCloseFlushesQueuedResponseWrites(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "close-flush.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + + total := 12 + for i := 0; i < total; i++ { + responseID := "resp_flush_" + strconv.Itoa(i) + state.saveResponseWithAccount(responseID, map[string]any{ + "id": responseID, + "object": "response", + "idx": i, + }, "conv-flush", "thread-flush", "flush@example.com") + } + + if err := state.Close(); err != nil { + t.Fatalf("state.Close failed: %v", err) + } + + readStore, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore(read) failed: %v", err) + } + defer func() { + _ = readStore.Close() + }() + + row := readStore.db.QueryRow(`SELECT COUNT(1) FROM responses WHERE conversation_id = ? 
AND thread_id = ?`, "conv-flush", "thread-flush") + var persisted int + if scanErr := row.Scan(&persisted); scanErr != nil { + t.Fatalf("scan persisted count failed: %v", scanErr) + } + if persisted != total { + t.Fatalf("persisted response count mismatch after close flush: got %d want %d", persisted, total) + } +} + +func TestSQLiteWriterFallbackMetricRemainsStableUnderNormalLoad(t *testing.T) { + tempDir := t.TempDir() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = filepath.Join(tempDir, "fallback-metric.sqlite") + cfg.Storage.PersistConversations = true + cfg.Storage.PersistResponses = boolPtr(true) + cfg.Responses.StoreTTLSeconds = 3600 + + beforeChannelFull := sqliteWriterFallbackValue("channel_full") + beforeUnavailable := sqliteWriterFallbackValue("writer_unavailable") + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + + for i := 0; i < 8; i++ { + responseID := "resp_metric_" + strconv.Itoa(i) + state.saveResponseWithAccount(responseID, map[string]any{ + "id": responseID, + "object": "response", + "idx": i, + }, "conv-metric", "thread-metric", "metric@example.com") + } + + time.Sleep(250 * time.Millisecond) + + afterChannelFull := sqliteWriterFallbackValue("channel_full") + afterUnavailable := sqliteWriterFallbackValue("writer_unavailable") + if afterChannelFull != beforeChannelFull { + t.Fatalf("expected no channel_full fallback in normal load; before=%d after=%d", beforeChannelFull, afterChannelFull) + } + if afterUnavailable != beforeUnavailable { + t.Fatalf("expected no writer_unavailable fallback in normal load; before=%d after=%d", beforeUnavailable, afterUnavailable) + } +} + func TestHandleChatCompletionsFreshThreadContinuesExplicitConversationIDWithLatestUserOnly(t *testing.T) { app := newFreshThreadTestApp(t) @@ -369,6 +1621,53 @@ func TestServerStateSaveAndApplyRejectsEmptyAPIKey(t *testing.T) { } } +func 
TestServerStateSaveAndApplyInvalidatesDispatchProbeCacheOnActiveAccountChange(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + cfg.Accounts = []NotionAccount{ + { + Email: "alice@example.com", + ProbeJSON: "probe_files/notion_accounts/alice/probe.json", + UserID: "alice-user", + SpaceID: "alice-space", + ClientVersion: "v1", + }, + { + Email: "bob@example.com", + ProbeJSON: "probe_files/notion_accounts/bob/probe.json", + UserID: "bob-user", + SpaceID: "bob-space", + ClientVersion: "v1", + }, + } + cfg.ActiveAccount = "alice@example.com" + + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + if state.DispatchProbeCache == nil { + t.Fatalf("expected dispatch probe cache to be initialized") + } + state.DispatchProbeCache.markSuccess("alice@example.com", time.Now()) + if state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected warm cache entry before active-account change") + } + + next := state.Config + next.ActiveAccount = "bob@example.com" + if err := state.SaveAndApply(next); err != nil { + t.Fatalf("SaveAndApply failed: %v", err) + } + if !state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected cache invalidation after active-account switch") + } +} + func TestHandleChatCompletionsStreamWritesErrorAfterHeadersSent(t *testing.T) { app := newFreshThreadTestApp(t) app.runPromptStreamSinkOverride = func(_ *http.Request, _ PromptRunRequest, sink InferenceStreamSink) (InferenceResult, error) { @@ -401,6 +1700,47 @@ func TestHandleChatCompletionsStreamWritesErrorAfterHeadersSent(t *testing.T) { } } +func TestHandleChatCompletionsStreamIncludeUsageFromTypedMessages(t *testing.T) { + app := newFreshThreadTestApp(t) + app.runPromptStreamSinkOverride = func(_ *http.Request, _ PromptRunRequest, sink InferenceStreamSink) 
(InferenceResult, error) { + if sink.Text != nil { + if err := sink.Text("hello "); err != nil { + t.Fatalf("stream text write failed: %v", err) + } + if err := sink.Text("world"); err != nil { + t.Fatalf("stream text write failed: %v", err) + } + } + return InferenceResult{ + Text: "hello world", + Prompt: "hello world", + ThreadID: "thread-stream-usage", + MessageID: "msg-stream-usage", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "stream": true, + "stream_options": map[string]any{"include_usage": true}, + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + body := rec.Body.String() + if !strings.Contains(body, "\"usage\"") { + t.Fatalf("expected stream output to include usage chunk, got body=%s", body) + } + if !strings.Contains(body, "data: [DONE]") { + t.Fatalf("expected stream done marker, got body=%s", body) + } +} + func TestNormalizeConfigDefaultsAccountMaxConcurrencyToOne(t *testing.T) { cfg := normalizeConfig(AppConfig{ APIKey: "test-api-key", @@ -487,3 +1827,562 @@ func TestRunPromptWithAccountPoolReturnsCapacityErrorWhenAllSlotsOccupied(t *tes t.Fatalf("expected wrapped sentinel error, got %v", runErr) } } + +func TestRefreshSessionInvalidatesDispatchProbeCacheOnSuccess(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Accounts: []NotionAccount{ + { + Email: "alice@example.com", + ProbeJSON: "/tmp/alice/probe.json", + StorageStatePath: "/tmp/alice/storage_state.json", + PendingStatePath: "/tmp/alice/pending_login.json", + UserID: "alice-user", + SpaceID: "alice-space", + UserName: "alice", + SpaceName: "alice-space-name", + ClientVersion: "v1", + Status: "ready", + }, + }, + ActiveAccount: "alice@example.com", + SessionRefresh: 
SessionRefreshConfig{ + Enabled: true, + RetryOnAuthError: true, + AutoSwitch: true, + }, + }) + state := &ServerState{ + Config: cfg, + Session: SessionInfo{UserID: "alice-user", SpaceID: "alice-space"}, + DispatchProbeCache: newProbeCache(), + ResponseStore: newResponseStore(45 * time.Second), + Conversations: newConversationStore(), + AdminTokens: map[string]time.Time{}, + AdminLoginAttempts: map[string]AdminLoginAttempt{}, + } + slot := &accountSlot{} + slot.max.Store(1) + slot.inflight.Store(0) + slotMap := map[string]*accountSlot{ + "alice@example.com": slot, + } + state.slots.Store(&slotMap) + syncDispatchSlotInflightFromSlots(slotMap) + state.DispatchProbeCache.markSuccess("alice@example.com", time.Now()) + + originalTryRefresh := testHookTryRefreshAccount + originalSaveAndApply := testHookSaveAndApply + defer func() { + testHookTryRefreshAccount = originalTryRefresh + testHookSaveAndApply = originalSaveAndApply + }() + + testHookTryRefreshAccount = func(ctx context.Context, cfg AppConfig, account NotionAccount) (AppConfig, error) { + account.Status = "ready" + account.LastError = "" + account.LastRefreshAt = time.Now().Format(time.RFC3339) + cfg.UpsertAccount(account) + return cfg, nil + } + testHookSaveAndApply = func(s *ServerState, cfg AppConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + s.Config = cfg + s.updateSnapshotBundleLocked() + return nil + } + + if err := state.RefreshSession(context.Background(), "test_refresh_success"); err != nil { + t.Fatalf("refresh session failed: %v", err) + } + if !state.DispatchProbeCache.shouldProbe("alice@example.com", 45*time.Second, time.Now()) { + t.Fatalf("expected probe cache to be invalidated after refresh success") + } +} + +func newSQLiteStoreTestConfig(path string) AppConfig { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = path + return cfg +} + +func TestOpenSQLiteStoreConfiguresReadWriteAndReadOnlyPools(t *testing.T) { + cfg := 
newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + if store.db == nil { + t.Fatalf("expected writable sqlite connection") + } + if store.roDB == nil { + t.Fatalf("expected read-only sqlite connection") + } + if got := store.db.Stats().MaxOpenConnections; got != 1 { + t.Fatalf("unexpected write db max open conns: got %d want 1", got) + } + wantReadConns := maxInt(2, runtime.NumCPU()) + if got := store.roDB.Stats().MaxOpenConnections; got != wantReadConns { + t.Fatalf("unexpected read db max open conns: got %d want %d", got, wantReadConns) + } +} + +func TestSQLiteStoreInitAppliesExtendedPragmas(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + + var mmapSize int64 + if err := store.db.QueryRow("PRAGMA mmap_size;").Scan(&mmapSize); err != nil { + t.Fatalf("query mmap_size failed: %v", err) + } + if mmapSize != 268435456 { + t.Fatalf("unexpected mmap_size: got %d want %d", mmapSize, int64(268435456)) + } + + var cacheSize int64 + if err := store.db.QueryRow("PRAGMA cache_size;").Scan(&cacheSize); err != nil { + t.Fatalf("query cache_size failed: %v", err) + } + if cacheSize != -65536 { + t.Fatalf("unexpected cache_size: got %d want %d", cacheSize, int64(-65536)) + } + + var tempStore int64 + if err := store.db.QueryRow("PRAGMA temp_store;").Scan(&tempStore); err != nil { + t.Fatalf("query temp_store failed: %v", err) + } + if tempStore != 2 { + t.Fatalf("unexpected temp_store: got %d want 2(memory)", tempStore) + } + + var autoCheckpoint int64 + if err := store.db.QueryRow("PRAGMA wal_autocheckpoint;").Scan(&autoCheckpoint); err != nil { + t.Fatalf("query wal_autocheckpoint failed: %v", err) + } + if 
autoCheckpoint != 1000 { + t.Fatalf("unexpected wal_autocheckpoint: got %d want 1000", autoCheckpoint) + } +} + +func TestSQLiteStoreReadOnlyConnectionRejectsWrites(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + _, err = store.roDB.Exec("CREATE TABLE read_only_write_should_fail(id INTEGER)") + if err == nil { + t.Fatalf("expected write on read-only connection to fail") + } + if !strings.Contains(strings.ToLower(err.Error()), "readonly") { + t.Fatalf("expected readonly error, got: %v", err) + } +} + +func TestSQLiteStoreLoadAccountsUsesReadOnlyConnection(t *testing.T) { + cfg := newSQLiteStoreTestConfig(filepath.Join(t.TempDir(), "notion2api.sqlite")) + store, err := openSQLiteStore(cfg) + if err != nil { + t.Fatalf("openSQLiteStore failed: %v", err) + } + defer func() { + _ = store.Close() + }() + + saveCfg := normalizeConfig(AppConfig{ + APIKey: "test-api-key", + Storage: StorageConfig{SQLitePath: cfg.Storage.SQLitePath}, + LoginHelper: LoginHelperConfig{SessionsDir: "probe_files/notion_accounts"}, + Accounts: []NotionAccount{{Email: "alice@example.com"}}, + ActiveAccount: "alice@example.com", + }) + if err := store.SaveAccounts(saveCfg); err != nil { + t.Fatalf("SaveAccounts failed: %v", err) + } + + if err := store.db.Close(); err != nil { + t.Fatalf("close write db failed: %v", err) + } + store.db = nil + accounts, activeAccount, ok, err := store.LoadAccounts() + if err != nil { + t.Fatalf("LoadAccounts failed: %v", err) + } + if !ok { + t.Fatalf("expected persisted accounts to be available") + } + if len(accounts) != 1 { + t.Fatalf("unexpected account count: got %d want 1", len(accounts)) + } + if getAccountEmailKey(accounts[0]) != "alice@example.com" { + t.Fatalf("unexpected loaded account email: %q", accounts[0].Email) + } + if canonicalEmailKey(activeAccount) 
!= "alice@example.com" { + t.Fatalf("unexpected active account: %q", activeAccount) + } +} + +func TestServeHTTPOptionsReturnsCORSNoContent(t *testing.T) { + app := newFreshThreadTestApp(t) + req := httptest.NewRequest(http.MethodOptions, "/v1/models", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("unexpected options status: got %d want %d", rec.Code, http.StatusNoContent) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != corsAllowOrigin { + t.Fatalf("unexpected Access-Control-Allow-Origin: got %q want %q", got, corsAllowOrigin) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != corsAllowHeaders { + t.Fatalf("unexpected Access-Control-Allow-Headers: got %q want %q", got, corsAllowHeaders) + } + if got := rec.Header().Get("Access-Control-Allow-Methods"); got != corsAllowMethods { + t.Fatalf("unexpected Access-Control-Allow-Methods: got %q want %q", got, corsAllowMethods) + } +} + +func TestServeIndexIncludesCORSHeaders(t *testing.T) { + app := newFreshThreadTestApp(t) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d want %d", rec.Code, http.StatusOK) + } + if got := rec.Header().Get("Access-Control-Allow-Origin"); got != corsAllowOrigin { + t.Fatalf("unexpected Access-Control-Allow-Origin: got %q want %q", got, corsAllowOrigin) + } + if got := rec.Header().Get("Access-Control-Allow-Headers"); got != corsAllowHeaders { + t.Fatalf("unexpected Access-Control-Allow-Headers: got %q want %q", got, corsAllowHeaders) + } + if got := rec.Header().Get("Access-Control-Allow-Methods"); got != corsAllowMethods { + t.Fatalf("unexpected Access-Control-Allow-Methods: got %q want %q", got, corsAllowMethods) + } +} + +func TestNormalizeConfigSetsMaxRequestBodyBytesDefault(t *testing.T) { + cfg := normalizeConfig(AppConfig{}) + if got := 
cfg.Limits.MaxRequestBodyBytes; got != 4*1024*1024 { + t.Fatalf("unexpected max request body bytes default: got %d want %d", got, int64(4*1024*1024)) + } +} + +func TestNormalizeConfigClampsNonPositiveMaxRequestBodyBytes(t *testing.T) { + cfg := normalizeConfig(AppConfig{Limits: LimitsConfig{MaxRequestBodyBytes: -1}}) + if got := cfg.Limits.MaxRequestBodyBytes; got != 4*1024*1024 { + t.Fatalf("unexpected max request body bytes clamp: got %d want %d", got, int64(4*1024*1024)) + } +} + +func TestHandleChatCompletionsRejectsTooLargeBody(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + cfg.Limits.MaxRequestBodyBytes = 128 + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + + oversizeText := strings.Repeat("x", 512) + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", mustJSONBody(t, map[string]any{ + "model": "gpt-5.4", + "messages": []map[string]any{ + {"role": "user", "content": oversizeText}, + }, + })) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("unexpected status: got %d want %d body=%s", rec.Code, http.StatusRequestEntityTooLarge, rec.Body.String()) + } + body := rec.Body.String() + if !strings.Contains(body, `"code":"request_too_large"`) { + t.Fatalf("expected request_too_large code, got %s", body) + } + if !strings.Contains(body, `"type":"invalid_request_error"`) { + t.Fatalf("expected invalid_request_error type, got %s", body) + } +} + +func TestDecodeBodyRawWithLimitRejectsTrailingContent(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{"a":1} {"b":2}`)) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err == nil { + 
t.Fatalf("expected trailing content error, got raw=%q", string(raw)) + } + if !strings.Contains(err.Error(), "invalid json") { + t.Fatalf("expected invalid json error, got %v", err) + } +} + +func TestDecodeBodyRawWithLimitNormalizesWhitespace(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(" \n\t {\"a\":1}\n\t ")) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err != nil { + t.Fatalf("decodeBodyRawWithLimit failed: %v", err) + } + if got := strings.TrimSpace(string(raw)); got != "{\"a\":1}" { + t.Fatalf("unexpected normalized raw body: got %q", got) + } +} + +func TestDecodeBodyRawWithLimitTreatsEmptyBodyAsEmptyObject(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(" \n\t ")) + raw, err := decodeBodyRawWithLimit(nil, req, 0) + if err != nil { + t.Fatalf("decodeBodyRawWithLimit failed: %v", err) + } + if string(raw) != "{}" { + t.Fatalf("expected empty object for empty body, got %q", string(raw)) + } +} + +func TestDecodeChatCompletionsRequestBodyFromRawFallsBackToMapOnTypedDecodeMismatch(t *testing.T) { + raw := []byte(`{"model":"gpt-5.4","messages":[{"role":"user","content":"hello"}],"group_names":[1]}`) + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + t.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload == nil { + t.Fatalf("expected payload fallback map to be populated") + } + messages := sliceValue(typed.Messages) + if len(messages) != 1 { + t.Fatalf("expected typed messages recovered via map fallback, got len=%d", len(messages)) + } + msg := mapValue(messages[0]) + if strings.TrimSpace(stringValue(msg["content"])) != "hello" { + t.Fatalf("expected fallback-typed message content 'hello', got %#v", msg["content"]) + } +} + +func TestDecodeChatCompletionsRequestBodyFromRawParsesStreamIncludeUsageWithoutMapFallback(t *testing.T) { + raw := 
[]byte(`{"model":"gpt-5.4","stream_options":{"include_usage":"1"},"messages":[{"role":"user","content":"hello"}]}`) + typed, payload, err := decodeChatCompletionsRequestBodyFromRaw(raw) + if err != nil { + t.Fatalf("decodeChatCompletionsRequestBodyFromRaw failed: %v", err) + } + if payload != nil { + t.Fatalf("expected typed decode path without map fallback") + } + if typed.StreamIncludeUsage == nil || !*typed.StreamIncludeUsage { + t.Fatalf("expected stream include_usage=true from typed decode path") + } +} + +func TestHandleChatCompletionsSillyTavernFallbackOnContinuePrefillKey(t *testing.T) { + app := newFreshThreadTestApp(t) + captured := PromptRunRequest{} + app.runPromptOverride = func(_ *http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "ok", + ThreadID: "thread-st-fallback", + AccountEmail: "seed@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(`{ + "model":"gpt-5.4", + "messages":[{"role":"user","content":"Hello there"}], + "continue_prefill":"...", + "group_names":[1] + }`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: got %d body=%s", rec.Code, rec.Body.String()) + } + if captured.ClientProfile != sillyTavernClientProfile { + t.Fatalf("expected sillytavern client profile, got %q", captured.ClientProfile) + } + if strings.TrimSpace(captured.Prompt) == "" { + t.Fatalf("expected non-empty prompt for sillytavern fallback") + } +} + +func TestDecodeResponsesRequestBodyFromRawFallsBackToMapOnTypedDecodeMismatch(t *testing.T) { + raw := []byte(`{"model":"gpt-5.4","input":"hello","attachments":[{"type":"file","file_url":"https://example.com/f.txt"}],"conversation_id":1}`) + typed, payload, err := decodeResponsesRequestBodyFromRaw(raw) + if 
err != nil { + t.Fatalf("decodeResponsesRequestBodyFromRaw failed: %v", err) + } + if payload == nil { + t.Fatalf("expected payload fallback map to be populated") + } + if strings.TrimSpace(typed.Model) != "gpt-5.4" { + t.Fatalf("unexpected model after fallback: %q", typed.Model) + } + if strings.TrimSpace(flattenContent(typed.Input)) != "hello" { + t.Fatalf("expected fallback-typed input 'hello', got %#v", typed.Input) + } + atts := sliceValue(typed.Attachments) + if len(atts) != 1 { + t.Fatalf("expected one attachment after fallback, got %d", len(atts)) + } +} + +func TestHandleResponsesTypedFirstDecodeFallbackOnConversationIDTypeMismatch(t *testing.T) { + app := newFreshThreadTestApp(t) + seeded := seedCompletedConversation(t, app, "conv-responses-fallback", "Please remember this.", "Remembered.", "thread-old-responses-fallback") + + var captured PromptRunRequest + app.runPromptOverride = func(_ *http.Request, request PromptRunRequest) (InferenceResult, error) { + captured = request + return InferenceResult{ + Text: "Summary ready.", + ThreadID: "thread-new-responses-fallback", + MessageID: "msg-new-responses-fallback", + TraceID: "trace-new-responses-fallback", + AccountEmail: "seed@example.com", + }, nil + } + + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{ + "model":"gpt-5.4", + "input":"Summarize that.", + "attachments":[{"type":"file","file_url":"https://example.com/f.txt"}], + "conversation_id":1 + }`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer test-api-key") + req.Header.Set("X-Conversation-ID", seeded.ID) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status mismatch: got %d body=%s", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("X-Conversation-ID"); got != seeded.ID { + t.Fatalf("conversation header mismatch: got %q want %q", got, seeded.ID) + } + if !captured.ForceLocalConversationContinue { 
+ t.Fatalf("expected ForceLocalConversationContinue to be enabled") + } + assertPromptContains(t, captured.Prompt, + "Continue the conversation using the transcript below.", + "[user]\nPlease remember this.", + "[assistant]\nRemembered.", + "[user]\nSummarize that.", + ) + assertConversationContinued(t, app, seeded.ID, "thread-new-responses-fallback", "Summary ready.") +} + +func TestCollectProbeModelPathsIncludesActiveAndAccountProbeJSON(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + ProbeJSON: " probe_files/notion_accounts/active/probe.json ", + Accounts: []NotionAccount{ + {ProbeJSON: "probe_files/notion_accounts/alpha/probe.json"}, + {ProbeJSON: "probe_files/notion_accounts/alpha/probe.json"}, + {ProbeJSON: " probe_files/notion_accounts/beta/probe.json "}, + }, + }) + paths := collectProbeModelPaths(cfg) + for i := range paths { + paths[i] = strings.ReplaceAll(paths[i], "\\", "/") + } + sort.Strings(paths) + expected := []string{ + "probe_files/notion_accounts/active/probe.json", + "probe_files/notion_accounts/alpha/probe.json", + "probe_files/notion_accounts/beta/probe.json", + } + if len(paths) != len(expected) { + t.Fatalf("unexpected path count: got %d want %d (%v)", len(paths), len(expected), paths) + } + for i := range expected { + if paths[i] != expected[i] { + t.Fatalf("unexpected path[%d]: got %q want %q", i, paths[i], expected[i]) + } + } +} + +func TestBuildModelRegistryLoadsProbeModelsFromActiveAndAccountPaths(t *testing.T) { + dir := t.TempDir() + activeProbe := filepath.Join(dir, "active-probe.json") + accountProbe := filepath.Join(dir, "account-probe.json") + activeBlob := `{"models":[{"model":"active-model-raw","modelMessage":"Active Model","modelFamily":"openai","displayGroup":"fast","isDisabled":false,"markdownChat":{"beta":false},"workflow":{"finalModelName":"active-notion-model","beta":false},"customAgent":{"finalModelName":"","beta":false}}]}` + accountBlob := `{"models":[{"model":"account-model-raw","modelMessage":"Account 
Model","modelFamily":"anthropic","displayGroup":"intelligent","isDisabled":false,"markdownChat":{"beta":false},"workflow":{"finalModelName":"account-notion-model","beta":false},"customAgent":{"finalModelName":"","beta":false}}]}` + writeProbeFile := func(path string, blob string) { + payload := map[string]any{ + "email": "tester@example.com", + "userId": "user-id", + "spaceId": "space-id", + "clientVersion": "v1", + "embeddedModels": blob, + } + encoded, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal probe payload failed: %v", err) + } + if err := os.WriteFile(path, encoded, 0o600); err != nil { + t.Fatalf("write probe payload failed: %v", err) + } + } + writeProbeFile(activeProbe, activeBlob) + writeProbeFile(accountProbe, accountBlob) + + cfg := normalizeConfig(AppConfig{ + ProbeJSON: activeProbe, + Accounts: []NotionAccount{ + {Email: "alpha@example.com", ProbeJSON: accountProbe}, + }, + }) + registry := buildModelRegistry(cfg) + if _, err := registry.Resolve("active-model", ""); err != nil { + t.Fatalf("expected active probe model to be loaded, got err=%v", err) + } + if _, err := registry.Resolve("account-model", ""); err != nil { + t.Fatalf("expected account probe model to be loaded, got err=%v", err) + } +} + +func TestDeleteAccountUsesCanonicalKeyComparison(t *testing.T) { + cfg := normalizeConfig(AppConfig{ + ActiveAccount: " Alice@Example.com ", + ProbeJSON: "probe_files/notion_accounts/alice/probe.json", + Accounts: []NotionAccount{ + {Email: "alice@example.com"}, + }, + }) + ok := cfg.DeleteAccount("ALICE@example.com") + if !ok { + t.Fatalf("expected delete to succeed") + } + if len(cfg.Accounts) != 0 { + t.Fatalf("expected accounts to be empty after delete, got %d", len(cfg.Accounts)) + } + if cfg.ActiveAccount != "" { + t.Fatalf("expected active account to be cleared, got %q", cfg.ActiveAccount) + } + if cfg.ProbeJSON != "" { + t.Fatalf("expected probe json to be cleared, got %q", cfg.ProbeJSON) + } +} diff --git 
a/internal/app/metrics.go b/internal/app/metrics.go new file mode 100644 index 0000000..e776879 --- /dev/null +++ b/internal/app/metrics.go @@ -0,0 +1,435 @@ +package app + +import ( + "expvar" + "fmt" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +type histogramSeries struct { + count uint64 + sum float64 + buckets []uint64 +} + +func newHistogramSeries(bucketCount int) *histogramSeries { + if bucketCount < 0 { + bucketCount = 0 + } + return &histogramSeries{ + buckets: make([]uint64, bucketCount), + } +} + +func (s *histogramSeries) observe(seconds float64, bounds []float64) { + if s == nil { + return + } + if seconds < 0 { + seconds = 0 + } + s.count++ + s.sum += seconds + for idx, bound := range bounds { + if seconds <= bound { + s.buckets[idx]++ + } + } +} + +type requestDurationKey struct { + Path string + Method string + Status string +} + +type sqliteDurationKey struct { + Op string +} + +var requestDurationBuckets = []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10} +var transportCallDurationBuckets = []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5} +var sqliteOpDurationBuckets = []float64{0.0005, 0.001, 0.0025, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1} + +var ( + requestDurationMu sync.Mutex + requestDurationSeries = map[requestDurationKey]*histogramSeries{} + + dispatchInflightMu sync.Mutex + dispatchInflight = map[string]int64{} + + transportCallMu sync.Mutex + transportCallSeries = newHistogramSeries(len(transportCallDurationBuckets)) + + browserSpawnMu sync.Mutex + browserSpawnTotal uint64 + + browserPoolWorkerMu sync.Mutex + browserPoolWorkerTotal uint64 + + sqliteDurationMu sync.Mutex + sqliteDurationSeries = map[sqliteDurationKey]*histogramSeries{} +) + +func resetMetricsForTest() { + requestDurationMu.Lock() + requestDurationSeries = map[requestDurationKey]*histogramSeries{} + requestDurationMu.Unlock() + + dispatchInflightMu.Lock() + dispatchInflight = map[string]int64{} + 
dispatchInflightMu.Unlock() + + transportCallMu.Lock() + transportCallSeries = newHistogramSeries(len(transportCallDurationBuckets)) + transportCallMu.Unlock() + + browserSpawnMu.Lock() + browserSpawnTotal = 0 + browserSpawnMu.Unlock() + + browserPoolWorkerMu.Lock() + browserPoolWorkerTotal = 0 + browserPoolWorkerMu.Unlock() + + sqliteDurationMu.Lock() + sqliteDurationSeries = map[sqliteDurationKey]*histogramSeries{} + sqliteDurationMu.Unlock() +} + +func observeRequestDuration(path string, method string, status int, elapsed time.Duration) { + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + key := requestDurationKey{ + Path: normalizeMetricsPathLabel(path), + Method: strings.ToUpper(strings.TrimSpace(method)), + Status: strconv.Itoa(status), + } + if key.Method == "" { + key.Method = "UNKNOWN" + } + requestDurationMu.Lock() + series := requestDurationSeries[key] + if series == nil { + series = newHistogramSeries(len(requestDurationBuckets)) + requestDurationSeries[key] = series + } + series.observe(seconds, requestDurationBuckets) + requestDurationMu.Unlock() +} + +func setDispatchSlotInflight(email string, inflight int) { + key := canonicalEmailKey(email) + if key == "" { + return + } + if inflight < 0 { + inflight = 0 + } + dispatchInflightMu.Lock() + dispatchInflight[key] = int64(inflight) + dispatchInflightMu.Unlock() +} + +func syncDispatchSlotInflightFromSlots(next map[string]*accountSlot) { + dispatchInflightMu.Lock() + defer dispatchInflightMu.Unlock() + for key := range dispatchInflight { + if _, ok := next[key]; !ok { + delete(dispatchInflight, key) + } + } + for key, slot := range next { + if slot == nil { + continue + } + inflight := slot.inflight.Load() + if inflight < 0 { + inflight = 0 + } + dispatchInflight[key] = int64(inflight) + } +} + +func observeTransportCallDuration(elapsed time.Duration) { + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + transportCallMu.Lock() + transportCallSeries.observe(seconds, 
transportCallDurationBuckets) + transportCallMu.Unlock() +} + +func addBrowserHelperSpawn() { + browserSpawnMu.Lock() + browserSpawnTotal++ + browserSpawnMu.Unlock() +} + +func addBrowserHelperPoolWorkerSpawn() { + browserPoolWorkerMu.Lock() + browserPoolWorkerTotal++ + browserPoolWorkerMu.Unlock() +} + +func observeSQLiteOpDuration(op string, elapsed time.Duration) { + op = strings.TrimSpace(strings.ToLower(op)) + if op == "" { + op = "unknown" + } + seconds := elapsed.Seconds() + if seconds < 0 { + seconds = 0 + } + key := sqliteDurationKey{Op: op} + sqliteDurationMu.Lock() + series := sqliteDurationSeries[key] + if series == nil { + series = newHistogramSeries(len(sqliteOpDurationBuckets)) + sqliteDurationSeries[key] = series + } + series.observe(seconds, sqliteOpDurationBuckets) + sqliteDurationMu.Unlock() +} + +func normalizeMetricsPathLabel(path string) string { + clean := strings.TrimSpace(path) + if clean == "" { + return "unknown" + } + switch { + case clean == "/": + return "/" + case clean == "/healthz": + return "/healthz" + case clean == "/metrics": + return "/metrics" + case clean == "/debug/vars": + return "/debug/vars" + case strings.HasPrefix(clean, "/v1/models/"): + return "/v1/models/:id" + case clean == "/v1/models": + return "/v1/models" + case strings.HasPrefix(clean, "/v1/responses/"): + return "/v1/responses/:id" + case clean == "/v1/responses": + return "/v1/responses" + case clean == "/v1/chat/completions": + return "/v1/chat/completions" + case clean == "/v1/st/chat/completions": + return "/v1/st/chat/completions" + case strings.HasPrefix(clean, "/admin/accounts/"): + return "/admin/accounts/:id" + case strings.HasPrefix(clean, "/admin/conversations/"): + return "/admin/conversations/:id" + case strings.HasPrefix(clean, "/admin"): + return "/admin/*" + } + if strings.Count(clean, "/") >= 2 { + parts := strings.Split(clean, "/") + if len(parts) >= 3 { + return "/" + parts[1] + "/" + parts[2] + "/*" + } + } + return clean +} + +func 
writePrometheusMetrics(w http.ResponseWriter) { + if w == nil { + return + } + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + w.WriteHeader(http.StatusOK) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_request_duration_seconds HTTP request duration seconds by path/method/status.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_request_duration_seconds histogram") + writeRequestDurationHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_dispatch_slot_inflight Current in-flight dispatch slots per account email.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_dispatch_slot_inflight gauge") + writeDispatchInflightGauge(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_transport_call_duration_seconds Duration of transport helper calls in seconds.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_transport_call_duration_seconds histogram") + writeTransportCallHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_browser_helper_spawn_total Total spawned browser helper subprocesses.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_browser_helper_spawn_total counter") + writeBrowserHelperSpawnCounter(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_browser_helper_pool_worker_spawn_total Total spawned persistent browser helper pool workers.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_browser_helper_pool_worker_spawn_total counter") + writeBrowserHelperPoolWorkerSpawnCounter(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_sqlite_op_duration_seconds SQLite operation durations in seconds by operation.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_sqlite_op_duration_seconds histogram") + writeSQLiteDurationHistogram(w) + + _, _ = fmt.Fprintln(w, "# HELP notion2api_response_store_prune_total Total number of pruned in-memory response entries by reason.") + _, _ = fmt.Fprintln(w, "# TYPE notion2api_response_store_prune_total counter") + writeResponseStorePruneCounter(w) +} + +func writeRequestDurationHistogram(w http.ResponseWriter) { + 
requestDurationMu.Lock() + seriesMap := make(map[requestDurationKey]*histogramSeries, len(requestDurationSeries)) + keys := make([]requestDurationKey, 0, len(requestDurationSeries)) + for key, series := range requestDurationSeries { + copySeries := *series + copySeries.buckets = append([]uint64(nil), series.buckets...) + seriesMap[key] = &copySeries + keys = append(keys, key) + } + requestDurationMu.Unlock() + + sort.Slice(keys, func(i, j int) bool { + if keys[i].Path != keys[j].Path { + return keys[i].Path < keys[j].Path + } + if keys[i].Method != keys[j].Method { + return keys[i].Method < keys[j].Method + } + return keys[i].Status < keys[j].Status + }) + + for _, key := range keys { + series := seriesMap[key] + if series == nil { + continue + } + labelPrefix := fmt.Sprintf("path=\"%s\",method=\"%s\",status=\"%s\"", + escapePrometheusLabelValue(key.Path), + escapePrometheusLabelValue(key.Method), + escapePrometheusLabelValue(key.Status), + ) + writeHistogramSeries(w, "notion2api_request_duration_seconds", labelPrefix, requestDurationBuckets, series) + } +} + +func writeDispatchInflightGauge(w http.ResponseWriter) { + dispatchInflightMu.Lock() + type pair struct { + email string + value int64 + } + items := make([]pair, 0, len(dispatchInflight)) + for email, value := range dispatchInflight { + items = append(items, pair{email: email, value: value}) + } + dispatchInflightMu.Unlock() + + sort.Slice(items, func(i, j int) bool { return items[i].email < items[j].email }) + for _, item := range items { + _, _ = fmt.Fprintf(w, "notion2api_dispatch_slot_inflight{email=\"%s\"} %d\n", + escapePrometheusLabelValue(item.email), item.value) + } +} + +func writeTransportCallHistogram(w http.ResponseWriter) { + transportCallMu.Lock() + series := *transportCallSeries + series.buckets = append([]uint64(nil), transportCallSeries.buckets...) 
+ transportCallMu.Unlock() + writeHistogramSeries(w, "notion2api_transport_call_duration_seconds", "", transportCallDurationBuckets, &series) +} + +func writeBrowserHelperSpawnCounter(w http.ResponseWriter) { + browserSpawnMu.Lock() + total := browserSpawnTotal + browserSpawnMu.Unlock() + _, _ = fmt.Fprintf(w, "notion2api_browser_helper_spawn_total %d\n", total) +} + +func writeBrowserHelperPoolWorkerSpawnCounter(w http.ResponseWriter) { + browserPoolWorkerMu.Lock() + total := browserPoolWorkerTotal + browserPoolWorkerMu.Unlock() + _, _ = fmt.Fprintf(w, "notion2api_browser_helper_pool_worker_spawn_total %d\n", total) +} + +func writeSQLiteDurationHistogram(w http.ResponseWriter) { + sqliteDurationMu.Lock() + seriesMap := make(map[sqliteDurationKey]*histogramSeries, len(sqliteDurationSeries)) + keys := make([]sqliteDurationKey, 0, len(sqliteDurationSeries)) + for key, series := range sqliteDurationSeries { + copySeries := *series + copySeries.buckets = append([]uint64(nil), series.buckets...) 
+ seriesMap[key] = &copySeries + keys = append(keys, key) + } + sqliteDurationMu.Unlock() + + sort.Slice(keys, func(i, j int) bool { return keys[i].Op < keys[j].Op }) + for _, key := range keys { + series := seriesMap[key] + if series == nil { + continue + } + labelPrefix := fmt.Sprintf("op=\"%s\"", escapePrometheusLabelValue(key.Op)) + writeHistogramSeries(w, "notion2api_sqlite_op_duration_seconds", labelPrefix, sqliteOpDurationBuckets, series) + } +} + +func writeResponseStorePruneCounter(w http.ResponseWriter) { + if w == nil { + return + } + entryVar := responseStorePruneTotalMetric.Get("expired_entries") + if entryVar == nil { + _, _ = fmt.Fprintln(w, "notion2api_response_store_prune_total{reason=\"expired_entries\"} 0") + return + } + entryValue, ok := entryVar.(*expvar.Int) + if !ok || entryValue == nil { + _, _ = fmt.Fprintln(w, "notion2api_response_store_prune_total{reason=\"expired_entries\"} 0") + return + } + _, _ = fmt.Fprintf(w, "notion2api_response_store_prune_total{reason=\"expired_entries\"} %d\n", entryValue.Value()) +} + +func writeHistogramSeries(w http.ResponseWriter, metricName string, baseLabels string, bounds []float64, series *histogramSeries) { + if w == nil || series == nil { + return + } + for idx, bound := range bounds { + le := strconv.FormatFloat(bound, 'g', -1, 64) + labels := withExtraLabel(baseLabels, "le", le) + _, _ = fmt.Fprintf(w, "%s_bucket{%s} %d\n", metricName, labels, series.buckets[idx]) + } + infLabels := withExtraLabel(baseLabels, "le", "+Inf") + _, _ = fmt.Fprintf(w, "%s_bucket{%s} %d\n", metricName, infLabels, series.count) + if baseLabels == "" { + _, _ = fmt.Fprintf(w, "%s_sum %s\n", metricName, formatFloat(series.sum)) + _, _ = fmt.Fprintf(w, "%s_count %d\n", metricName, series.count) + return + } + _, _ = fmt.Fprintf(w, "%s_sum{%s} %s\n", metricName, baseLabels, formatFloat(series.sum)) + _, _ = fmt.Fprintf(w, "%s_count{%s} %d\n", metricName, baseLabels, series.count) +} + +func withExtraLabel(base string, name string, 
value string) string { + extra := fmt.Sprintf("%s=\"%s\"", name, escapePrometheusLabelValue(value)) + if strings.TrimSpace(base) == "" { + return extra + } + return base + "," + extra +} + +func escapePrometheusLabelValue(value string) string { + value = strings.ReplaceAll(value, `\`, `\\`) + value = strings.ReplaceAll(value, "\n", `\n`) + value = strings.ReplaceAll(value, `"`, `\"`) + return value +} + +func formatFloat(value float64) string { + return strconv.FormatFloat(value, 'g', -1, 64) +} diff --git a/internal/app/models.go b/internal/app/models.go index 3ad7ae0..537290c 100644 --- a/internal/app/models.go +++ b/internal/app/models.go @@ -6,6 +6,7 @@ import ( "os" "sort" "strings" + "sync" "unicode" ) @@ -56,7 +57,7 @@ func builtinModelDefinitions() []ModelDefinition { func buildModelRegistry(cfg AppConfig) ModelRegistry { entries := builtinModelDefinitions() - if probeEntries := extractProbeModelDefinitions(cfg.ProbeJSON); len(probeEntries) > 0 { + if probeEntries := extractProbeModelDefinitions(collectProbeModelPaths(cfg)); len(probeEntries) > 0 { entries = mergeModelDefinitions(entries, probeEntries) } if len(cfg.Models) > 0 { @@ -98,6 +99,27 @@ func buildModelRegistry(cfg AppConfig) ModelRegistry { return ModelRegistry{Entries: entries, ByID: byID, AliasToID: aliasToID} } +func collectProbeModelPaths(cfg AppConfig) []string { + paths := make([]string, 0, len(cfg.Accounts)+1) + seen := map[string]struct{}{} + appendPath := func(path string) { + clean := strings.TrimSpace(path) + if clean == "" { + return + } + if _, exists := seen[clean]; exists { + return + } + seen[clean] = struct{}{} + paths = append(paths, clean) + } + appendPath(cfg.ProbeJSON) + for _, account := range cfg.Accounts { + appendPath(account.ProbeJSON) + } + return paths +} + func (r ModelRegistry) Resolve(value string, fallback string) (ModelDefinition, error) { candidate := strings.TrimSpace(value) if candidate == "" { @@ -118,7 +140,54 @@ func (r ModelRegistry) Resolve(value string, 
fallback string) (ModelDefinition, return ModelDefinition{}, fmt.Errorf("unknown model: %s", candidate) } -func extractProbeModelDefinitions(path string) []ModelDefinition { +func extractProbeModelDefinitions(paths []string) []ModelDefinition { + if len(paths) == 0 { + return nil + } + parseConcurrency := len(paths) + if parseConcurrency > 4 { + parseConcurrency = 4 + } + if parseConcurrency < 1 { + parseConcurrency = 1 + } + type indexedResult struct { + index int + items []ModelDefinition + } + results := make([]indexedResult, len(paths)) + sem := make(chan struct{}, parseConcurrency) + var wg sync.WaitGroup + for i, path := range paths { + wg.Add(1) + go func(index int, probePath string) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + results[index] = indexedResult{ + index: index, + items: extractProbeModelDefinitionsFromPath(probePath), + } + }(i, path) + } + wg.Wait() + + seen := map[string]struct{}{} + out := make([]ModelDefinition, 0) + for _, result := range results { + for _, item := range result.items { + key := item.ID + "|" + item.NotionModel + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + out = append(out, item) + } + } + return out +} + +func extractProbeModelDefinitionsFromPath(path string) []ModelDefinition { if strings.TrimSpace(path) == "" { return nil } diff --git a/internal/app/notion_client.go b/internal/app/notion_client.go index 38f6087..b71a48c 100644 --- a/internal/app/notion_client.go +++ b/internal/app/notion_client.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "encoding/json" "errors" + "expvar" "fmt" "io" "log" @@ -19,6 +20,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -36,6 +38,31 @@ const ( var leadingLangTagPattern = regexp.MustCompile(`(?is)^\s*(?:]*>|)\s*`) var prefixedTranscriptStepIDPattern = regexp.MustCompile(`^(?:cfg|ctx|upd)_([0-9a-fA-F]{32})$`) +var notionHTTPTransportCacheMetric = expvar.NewMap("notion2api_http_transport_cache_total") + +type 
notionHTTPTransportCacheKey struct { + UpstreamBaseURL string + UpstreamOriginURL string + UpstreamHostHeader string + UpstreamTLSServerName string + UpstreamUseEnvProxy bool + ProxyMode string + ProxyURL string + ProxyHTTPURL string + ProxyHTTPSURL string + ResinEnabled bool + ResinURL string + ResinPlatform string + ResinMode string + AccountEmailKey string +} + +var notionTransportCache = struct { + mu sync.RWMutex + items map[notionHTTPTransportCacheKey]*http.Transport +}{ + items: map[notionHTTPTransportCacheKey]*http.Transport{}, +} func bestEffortTimeout(parent context.Context, cap time.Duration) time.Duration { if cap <= 0 { @@ -385,12 +412,13 @@ type ndjsonPatchOperation struct { V any `json:"v"` } -type ndjsonEnvelope struct { - Type string `json:"type"` - Data map[string]any `json:"data,omitempty"` - Version int `json:"version,omitempty"` - V []ndjsonPatchOperation `json:"v,omitempty"` - RecordMap map[string]any `json:"recordMap,omitempty"` +type ndjsonStreamLine struct { + Type string `json:"type"` + V []ndjsonPatchOperation `json:"v,omitempty"` + RecordMap map[string]any `json:"recordMap,omitempty"` + ID string `json:"id,omitempty"` + FinishedAt any `json:"finishedAt,omitempty"` + Value []ndjsonAgentInferenceValue `json:"value,omitempty"` } type ndjsonAgentInferenceValue struct { @@ -623,10 +651,37 @@ func newNotionAIStreamingClient(session SessionInfo, cfg AppConfig, accountEmail return newNotionAIClientWithMode(session, cfg, accountEmail, true) } -func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail string, streaming bool) *NotionAIClient { +func buildNotionHTTPTransportCacheKey(cfg AppConfig, accountEmail string) notionHTTPTransportCacheKey { normalizedCfg := normalizeConfig(cfg) - resolver := NewProxyResolver(normalizedCfg) upstream := normalizedCfg.NotionUpstream() + policy := normalizedCfg.ResolveProxyPolicyForAccount(accountEmail) + return notionHTTPTransportCacheKey{ + UpstreamBaseURL: 
strings.TrimSpace(upstream.BaseURL), + UpstreamOriginURL: strings.TrimSpace(upstream.OriginURL), + UpstreamHostHeader: strings.TrimSpace(upstream.HostHeader), + UpstreamTLSServerName: strings.TrimSpace(upstream.TLSServerName), + UpstreamUseEnvProxy: upstream.UseEnvProxy, + ProxyMode: strings.TrimSpace(policy.Mode), + ProxyURL: strings.TrimSpace(policy.URL), + ProxyHTTPURL: strings.TrimSpace(policy.HTTPURL), + ProxyHTTPSURL: strings.TrimSpace(policy.HTTPSURL), + ResinEnabled: policy.Resin.Enabled, + ResinURL: strings.TrimSpace(policy.Resin.URL), + ResinPlatform: strings.TrimSpace(policy.Resin.Platform), + ResinMode: strings.TrimSpace(policy.Resin.Mode), + AccountEmailKey: canonicalEmailKey(accountEmail), + } +} + +func cachedNotionHTTPTransport(cfg AppConfig, accountEmail string, resolver *ProxyResolver, upstream NotionUpstream) *http.Transport { + key := buildNotionHTTPTransportCacheKey(cfg, accountEmail) + notionTransportCache.mu.RLock() + cached := notionTransportCache.items[key] + notionTransportCache.mu.RUnlock() + if cached != nil { + notionHTTPTransportCacheMetric.Add("hit_rlock", 1) + return cached + } tlsConfig := &tls.Config{InsecureSkipVerify: true} if strings.TrimSpace(upstream.TLSServerName) != "" { tlsConfig.ServerName = strings.TrimSpace(upstream.TLSServerName) @@ -650,6 +705,23 @@ func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail return proxyFunc(req) }, } + notionTransportCache.mu.Lock() + if existing := notionTransportCache.items[key]; existing != nil { + notionTransportCache.mu.Unlock() + notionHTTPTransportCacheMetric.Add("hit_lock", 1) + return existing + } + notionTransportCache.items[key] = transport + notionTransportCache.mu.Unlock() + notionHTTPTransportCacheMetric.Add("miss_new", 1) + return transport +} + +func newNotionAIClientWithMode(session SessionInfo, cfg AppConfig, accountEmail string, streaming bool) *NotionAIClient { + normalizedCfg := normalizeConfig(cfg) + resolver := NewProxyResolver(normalizedCfg) 
+ upstream := normalizedCfg.NotionUpstream() + transport := cachedNotionHTTPTransport(normalizedCfg, accountEmail, resolver, upstream) timeout := requestTimeout(normalizedCfg) clientTimeout := timeout if streaming { @@ -1083,7 +1155,9 @@ func (c *NotionAIClient) runInferenceTranscriptWithFallback(ctx context.Context, if c.Config.DebugUpstream { log.Printf("[debug_upstream] runInferenceTranscript http start thread_id=%s", threadID) } + callStartedAt := time.Now() parsed, err := c.runInferenceTranscriptHTTP(ctx, payload, threadID, sink) + observeTransportCallDuration(time.Since(callStartedAt)) if c.Config.DebugUpstream { log.Printf("[debug_upstream] runInferenceTranscript http done thread_id=%s line_count=%d message_ids=%d err=%v", threadID, parsed.LineCount, len(parsed.MessageIDs), err) } @@ -2336,26 +2410,28 @@ func (s *ndjsonTranscriptState) handleLine(line []byte, threadID string, sink In if len(line) == 0 { return nil } - var envelope ndjsonEnvelope - if err := json.Unmarshal(line, &envelope); err != nil { + var streamLine ndjsonStreamLine + if err := json.Unmarshal(line, &streamLine); err != nil { return err } s.LineCount++ - switch envelope.Type { + switch streamLine.Type { case "patch": - for _, op := range envelope.V { + for _, op := range streamLine.V { if err := s.applyPatchOperation(op, sink); err != nil { return err } } case "agent-inference": - var event ndjsonAgentInferenceEvent - if err := json.Unmarshal(line, &event); err != nil { - return err + event := ndjsonAgentInferenceEvent{ + Type: streamLine.Type, + ID: streamLine.ID, + FinishedAt: streamLine.FinishedAt, + Value: streamLine.Value, } return s.mergeAgentInferenceEvent(event, sink) case "record-map": - messageIDs, agent, outcomeErr, ok := finalThreadOutcomeFromRecordMap(envelope.RecordMap, threadID) + messageIDs, agent, outcomeErr, ok := finalThreadOutcomeFromRecordMap(streamLine.RecordMap, threadID) if len(messageIDs) > 0 { s.MessageIDs = messageIDs } @@ -2395,52 +2471,79 @@ func (s 
*ndjsonTranscriptState) result() ndjsonParseResult { func consumeNDJSONStream(reader io.Reader, threadID string, sink InferenceStreamSink) (ndjsonParseResult, error) { state := &ndjsonTranscriptState{ActiveAgentIndex: -1} - buffered := bufio.NewReader(reader) - for { - line, err := buffered.ReadBytes('\n') - if len(line) > 0 { - if handleErr := state.handleLine(line, threadID, sink); handleErr != nil { - return state.result(), handleErr - } - if state.hasTerminalAnswer() { - return state.result(), nil - } + scanner := newNDJSONScanner(reader) + for scanner.Scan() { + line := scanner.Bytes() + if handleErr := state.handleLine(line, threadID, sink); handleErr != nil { + return state.result(), handleErr } - if err != nil { - if errors.Is(err, io.EOF) { - break - } - return state.result(), err + if state.hasTerminalAnswer() { + return state.result(), nil } } + if err := normalizeNDJSONScanError(scanner.Err()); err != nil { + return state.result(), err + } return state.result(), nil } var ndjsonIdleAfterAnswerTimeout = 5 * time.Second +var errNDJSONLineTooLarge = errors.New("ndjson line too large") + +const ( + ndjsonScannerInitialBuffer = 64 * 1024 + ndjsonMaxLineBytes = 16 * 1024 * 1024 +) type ndjsonReadEvent struct { line []byte err error } +func newNDJSONScanner(reader io.Reader) *bufio.Scanner { + scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 0, ndjsonScannerInitialBuffer), ndjsonMaxLineBytes) + return scanner +} + +func normalizeNDJSONScanError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, bufio.ErrTooLong) { + return fmt.Errorf("%w: exceeds %d bytes", errNDJSONLineTooLarge, ndjsonMaxLineBytes) + } + return err +} + func consumeNDJSONStreamWithIdleClose(reader io.ReadCloser, threadID string, sink InferenceStreamSink, idleAfterAnswer time.Duration) (ndjsonParseResult, error) { state := &ndjsonTranscriptState{ActiveAgentIndex: -1} - buffered := bufio.NewReader(reader) events := make(chan ndjsonReadEvent, 1) done := 
make(chan struct{}) defer close(done) go func() { - for { - line, err := buffered.ReadBytes('\n') + scanner := newNDJSONScanner(reader) + for scanner.Scan() { + line := append([]byte(nil), scanner.Bytes()...) select { - case events <- ndjsonReadEvent{line: line, err: err}: + case events <- ndjsonReadEvent{line: line}: case <-done: return } - if err != nil { + } + if err := normalizeNDJSONScanError(scanner.Err()); err != nil { + select { + case events <- ndjsonReadEvent{err: err}: + case <-done: return } + return + } + select { + case events <- ndjsonReadEvent{err: io.EOF}: + case <-done: + return } }() diff --git a/internal/app/notion_client_best_effort_test.go b/internal/app/notion_client_best_effort_test.go index bb04158..031629c 100644 --- a/internal/app/notion_client_best_effort_test.go +++ b/internal/app/notion_client_best_effort_test.go @@ -1,8 +1,11 @@ package app import ( + "bytes" "context" "encoding/json" + "errors" + "expvar" "io" "net/http" "net/http/httptest" @@ -124,11 +127,128 @@ func TestProbeAccountProtocolHealthIgnoresContextAbort(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) defer cancel() - if err := app.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := app.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { t.Fatalf("expected context abort probe to be ignored, got %v", err) } } +func TestProbeAccountProtocolHealthCachesProbeSuccessWithinTTL(t *testing.T) { + cfg := defaultConfig() + cfg.Dispatch.ProbeCacheTTLSeconds = 45 + state := &ServerState{DispatchProbeCache: newProbeCache()} + app := &App{ + State: state, + } + callCount := 0 + app.accountProtocolProbeOverride = func(ctx context.Context, cfg AppConfig, session SessionInfo) error { + callCount++ + return nil + } + + session := SessionInfo{ + UserEmail: "alice@example.com", + } + ctx := context.Background() + + for i := 0; i < 10; i++ { + if err := app.probeAccountProtocolHealth(ctx, cfg, session, 
"alice@example.com"); err != nil { + t.Fatalf("probe call %d failed: %v", i+1, err) + } + } + if callCount != 1 { + t.Fatalf("expected one upstream probe call within ttl window, got %d", callCount) + } +} + +func TestProbeAccountProtocolHealthReprobesAfterFailure(t *testing.T) { + cfg := defaultConfig() + cfg.Dispatch.ProbeCacheTTLSeconds = 45 + state := &ServerState{DispatchProbeCache: newProbeCache()} + app := &App{ + State: state, + } + callCount := 0 + app.accountProtocolProbeOverride = func(ctx context.Context, cfg AppConfig, session SessionInfo) error { + callCount++ + if callCount == 1 { + return errors.New("probe failed once") + } + return nil + } + + session := SessionInfo{ + UserEmail: "alice@example.com", + } + ctx := context.Background() + + if err := app.probeAccountProtocolHealth(ctx, cfg, session, "alice@example.com"); err == nil { + t.Fatalf("expected first probe failure") + } + if err := app.probeAccountProtocolHealth(ctx, cfg, session, "alice@example.com"); err != nil { + t.Fatalf("expected second probe to run and succeed, got %v", err) + } + if callCount != 2 { + t.Fatalf("expected second request to reprobe after failure, got callCount=%d", callCount) + } +} + +func TestRunPromptWithSessionIncrementsWreqClientMetric(t *testing.T) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Storage.SQLitePath = "" + state, err := newServerState(cfg) + if err != nil { + t.Fatalf("newServerState failed: %v", err) + } + defer func() { + _ = state.Close() + }() + app := &App{State: state} + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + beforeStandard := int64(0) + if v := transportClientNewTotalMetric.Get("standard"); v != nil { + beforeStandard = v.(*expvar.Int).Value() + } + beforeStreaming := int64(0) + if v := transportClientNewTotalMetric.Get("streaming"); v != nil { + beforeStreaming = 
v.(*expvar.Int).Value() + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = app.runPromptWithSession(ctx, cfg, session, "", PromptRunRequest{Prompt: "hi"}, nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } + _, err = app.runPromptWithSession(ctx, cfg, session, "", PromptRunRequest{Prompt: "hi"}, func(string) error { return nil }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled for streaming run, got %v", err) + } + + afterStandard := int64(0) + if v := transportClientNewTotalMetric.Get("standard"); v != nil { + afterStandard = v.(*expvar.Int).Value() + } + afterStreaming := int64(0) + if v := transportClientNewTotalMetric.Get("streaming"); v != nil { + afterStreaming = v.(*expvar.Int).Value() + } + if afterStandard-beforeStandard < 1 { + t.Fatalf("expected standard metric increment, before=%d after=%d", beforeStandard, afterStandard) + } + if afterStreaming-beforeStreaming < 1 { + t.Fatalf("expected streaming metric increment, before=%d after=%d", beforeStreaming, afterStreaming) + } +} + func TestConsumeNDJSONStreamWithIdleCloseReturnsUpstreamErrorStep(t *testing.T) { threadID := "thread-error" messageID := "msg-error" @@ -147,6 +267,68 @@ func TestConsumeNDJSONStreamWithIdleCloseReturnsUpstreamErrorStep(t *testing.T) } } +func TestConsumeNDJSONStreamWithIdleCloseParsesFinalLineWithoutTrailingNewline(t *testing.T) { + threadID := "thread-error-no-newline" + messageID := "msg-error-no-newline" + recordMap := buildThreadErrorRecordMap(threadID, "test-space", messageID, "AI inference is not allowed.", "trust-rule-denied", "trace-error") + line, err := json.Marshal(map[string]any{ + "type": "record-map", + "recordMap": recordMap, + }) + if err != nil { + t.Fatalf("marshal ndjson line failed: %v", err) + } + + _, gotErr := consumeNDJSONStreamWithIdleClose(io.NopCloser(strings.NewReader(string(line))), threadID, InferenceStreamSink{}, 0) + if gotErr == 
nil || !strings.Contains(gotErr.Error(), "AI inference is not allowed") { + t.Fatalf("expected upstream error step without trailing newline, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamWithIdleCloseRejectsOversizedLine(t *testing.T) { + threadID := "thread-oversized-line" + oversizedLine := append(bytes.Repeat([]byte("a"), ndjsonMaxLineBytes+1), '\n') + + _, gotErr := consumeNDJSONStreamWithIdleClose(io.NopCloser(bytes.NewReader(oversizedLine)), threadID, InferenceStreamSink{}, 0) + if gotErr == nil { + t.Fatalf("expected oversized NDJSON line error, got nil") + } + if !errors.Is(gotErr, errNDJSONLineTooLarge) { + t.Fatalf("expected errNDJSONLineTooLarge, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamParsesFinalLineWithoutTrailingNewline(t *testing.T) { + threadID := "thread-error-no-newline-fallback" + messageID := "msg-error-no-newline-fallback" + recordMap := buildThreadErrorRecordMap(threadID, "test-space", messageID, "AI inference is not allowed.", "trust-rule-denied", "trace-error") + line, err := json.Marshal(map[string]any{ + "type": "record-map", + "recordMap": recordMap, + }) + if err != nil { + t.Fatalf("marshal ndjson line failed: %v", err) + } + + _, gotErr := consumeNDJSONStream(strings.NewReader(string(line)), threadID, InferenceStreamSink{}) + if gotErr == nil || !strings.Contains(gotErr.Error(), "AI inference is not allowed") { + t.Fatalf("expected upstream error step without trailing newline, got %v", gotErr) + } +} + +func TestConsumeNDJSONStreamRejectsOversizedLine(t *testing.T) { + threadID := "thread-oversized-line-fallback" + oversizedLine := append(bytes.Repeat([]byte("a"), ndjsonMaxLineBytes+1), '\n') + + _, gotErr := consumeNDJSONStream(bytes.NewReader(oversizedLine), threadID, InferenceStreamSink{}) + if gotErr == nil { + t.Fatalf("expected oversized NDJSON line error, got nil") + } + if !errors.Is(gotErr, errNDJSONLineTooLarge) { + t.Fatalf("expected errNDJSONLineTooLarge, got %v", gotErr) + } +} + func 
TestRunPromptReturnsUpstreamErrorStep(t *testing.T) { messageID := "msg-error" var recordMap map[string]any diff --git a/internal/app/notion_client_browser_transport.go b/internal/app/notion_client_browser_transport.go index caf2be0..2985586 100644 --- a/internal/app/notion_client_browser_transport.go +++ b/internal/app/notion_client_browser_transport.go @@ -1,23 +1,16 @@ package app import ( - "bytes" "context" - "encoding/json" - "errors" "fmt" "net/url" - "os" - "os/exec" - "path/filepath" "strings" "time" ) const ( - browserHelperCancelWaitDelay = 2 * time.Second - notionWreqDefaultBrowserProfile = "chrome_142" - notionWreqDefaultRequestTimeout = 120 * time.Second + notionTransportDefaultBrowserProfile = "chrome_142" + notionTransportDefaultRequestTimeout = 120 * time.Second ) type browserTransportRequest struct { @@ -33,27 +26,6 @@ type browserTransportRequest struct { IdleAfterAnswerMS int `json:"idle_after_answer_ms"` } -type browserTransportResponse struct { - Text string `json:"text"` - Status int `json:"status"` - ContentType string `json:"content_type"` -} - -type browserHelperUnavailableError struct { - Message string -} - -var ( - runBrowserFallback = runInferenceTranscriptInBrowserWithNodeWreq -) - -func (e *browserHelperUnavailableError) Error() string { - if e == nil { - return "" - } - return strings.TrimSpace(e.Message) -} - func detectInferenceStreamResponseFormat(body string) error { trimmed := strings.TrimSpace(strings.TrimPrefix(body, "\uFEFF")) if trimmed == "" { @@ -81,149 +53,7 @@ func runInferenceTranscriptInBrowser(ctx context.Context, client *NotionAIClient if len(client.Session.Cookies) == 0 { return "", fmt.Errorf("browser transport requires session cookies") } - return runBrowserFallback(ctx, client, payload) -} - -func runInferenceTranscriptInBrowserWithNodeWreq(ctx context.Context, client *NotionAIClient, payload map[string]any) (string, error) { - request, err := buildBrowserTransportRequest(client, payload) - if err != nil { - 
return "", err - } - return runHelperScript(ctx, "node", ".cjs", nodeWreqHelperScript(), request, browserHelperNodeEnv()) -} - -func runHelperScript(ctx context.Context, runtimeName string, extension string, script string, request browserTransportRequest, extraEnv []string) (string, error) { - requestPayload, err := json.Marshal(request) - if err != nil { - return "", err - } - stdoutBytes, err := executeHelperSubprocess(ctx, runtimeName, extension, script, requestPayload, extraEnv) - if err != nil { - return "", err - } - var response browserTransportResponse - if err := json.Unmarshal(stdoutBytes, &response); err != nil { - return "", fmt.Errorf("%s helper returned invalid json: %w", runtimeName, err) - } - if strings.TrimSpace(response.Text) == "" { - return "", fmt.Errorf("%s helper returned empty response (status=%d content_type=%q)", runtimeName, response.Status, response.ContentType) - } - if err := detectInferenceStreamResponseFormat(response.Text); err != nil { - return "", err - } - return response.Text, nil -} - -func executeHelperSubprocess(ctx context.Context, runtimeName string, extension string, script string, requestPayload []byte, extraEnv []string) ([]byte, error) { - if _, err := exec.LookPath(runtimeName); err != nil { - return nil, &browserHelperUnavailableError{Message: fmt.Sprintf("%s not found", runtimeName)} - } - scriptFile, err := os.CreateTemp("", "notion-browser-helper-*"+extension) - if err != nil { - return nil, err - } - scriptPath := scriptFile.Name() - defer os.Remove(scriptPath) - if _, err := scriptFile.WriteString(script); err != nil { - _ = scriptFile.Close() - return nil, err - } - if err := scriptFile.Close(); err != nil { - return nil, err - } - cmd := newBrowserHelperCommand(ctx, runtimeName, scriptPath, requestPayload, extraEnv) - var stdout bytes.Buffer - var stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - if err := runBrowserHelperCommand(ctx, cmd); err != nil { - return nil, 
classifyBrowserHelperExecError(ctx, runtimeName, err, stderr.String()) - } - return stdout.Bytes(), nil -} -func newBrowserHelperCommand(ctx context.Context, runtimeName string, scriptPath string, requestPayload []byte, extraEnv []string) *exec.Cmd { - _ = ctx - cmd := exec.CommandContext(context.Background(), runtimeName, scriptPath) - cmd.Stdin = bytes.NewReader(requestPayload) - cmd.Env = append(os.Environ(), extraEnv...) - cmd.WaitDelay = browserHelperCancelWaitDelay - cmd.Cancel = func() error { - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) { - return err - } - return nil - } - configureBrowserHelperCommand(cmd) - return cmd -} - -func runBrowserHelperCommand(ctx context.Context, cmd *exec.Cmd) error { - if cmd == nil { - return fmt.Errorf("browser helper command is nil") - } - if err := cmd.Start(); err != nil { - return err - } - waitCh := make(chan error, 1) - go func() { - waitCh <- cmd.Wait() - }() - if ctx == nil { - return <-waitCh - } - select { - case err := <-waitCh: - return err - case <-ctx.Done(): - cancelErr := cancelBrowserHelperCommand(cmd) - waitErr := <-waitCh - if waitErr != nil { - return waitErr - } - if cancelErr != nil { - return cancelErr - } - return ctx.Err() - } -} - -func cancelBrowserHelperCommand(cmd *exec.Cmd) error { - if cmd == nil { - return nil - } - if cmd.Cancel != nil { - return cmd.Cancel() - } - if cmd.Process == nil { - return nil - } - if err := cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) { - return err - } - return nil -} - -func classifyBrowserHelperExecError(ctx context.Context, runtimeName string, runErr error, stderrText string) error { - if errors.Is(runErr, exec.ErrNotFound) { - return &browserHelperUnavailableError{Message: fmt.Sprintf("%s not found", runtimeName)} - } - if ctx != nil { - if ctxErr := ctx.Err(); ctxErr != nil { - return ctxErr - } - } - trimmed := strings.TrimSpace(stderrText) - lower := 
strings.ToLower(trimmed) - if strings.Contains(lower, "cannot find module") && strings.Contains(lower, "node-wreq") { - return &browserHelperUnavailableError{Message: fmt.Sprintf("%s missing node-wreq module", runtimeName)} - } - if trimmed == "" { - trimmed = runErr.Error() - } - return fmt.Errorf("%s helper failed: %s", runtimeName, trimmed) + return runInferenceTranscriptInBrowserWithSurf(ctx, client, payload) } func buildBrowserTransportRequest(client *NotionAIClient, payload map[string]any) (browserTransportRequest, error) { @@ -264,60 +94,13 @@ func buildBrowserTransportRequest(client *NotionAIClient, payload map[string]any Headers: headers, Payload: payload, Cookies: client.Session.Cookies, - BrowserProfile: notionWreqDefaultBrowserProfile, + BrowserProfile: notionTransportDefaultBrowserProfile, Proxy: proxyValue, - RequestTimeoutMS: int(notionWreqDefaultRequestTimeout / time.Millisecond), + RequestTimeoutMS: int(notionTransportDefaultRequestTimeout / time.Millisecond), IdleAfterAnswerMS: int(ndjsonIdleAfterAnswerTimeout / time.Millisecond), }, nil } -func browserHelperNodeEnv() []string { - candidates := []string{} - for _, candidate := range browserHelperNodeModuleCandidates() { - if strings.TrimSpace(candidate) == "" { - continue - } - if stat, err := os.Stat(candidate); err == nil && stat.IsDir() { - candidates = append(candidates, candidate) - } - } - if len(candidates) == 0 { - return nil - } - joined := strings.Join(candidates, string(os.PathListSeparator)) - if existing := strings.TrimSpace(os.Getenv("NODE_PATH")); existing != "" { - joined = existing + string(os.PathListSeparator) + joined - } - return []string{"NODE_PATH=" + joined} -} - -func browserHelperNodeModuleCandidates() []string { - candidates := []string{ - os.Getenv("NODE_PATH"), - "/opt/notion2api-helper/node_modules", - } - if cwd, err := os.Getwd(); err == nil { - candidates = append(candidates, filepath.Join(cwd, "node_modules")) - } - if executable, err := os.Executable(); err == 
nil { - candidates = append(candidates, filepath.Join(filepath.Dir(executable), "node_modules")) - } - return splitPathListCandidates(candidates) -} - -func splitPathListCandidates(values []string) []string { - candidates := []string{} - for _, value := range values { - for _, item := range filepath.SplitList(strings.TrimSpace(value)) { - if strings.TrimSpace(item) == "" { - continue - } - candidates = append(candidates, item) - } - } - return candidates -} - func (c *NotionAIClient) supportsBrowserRunInferenceFallback() bool { if c == nil { return false diff --git a/internal/app/notion_client_login_transport.go b/internal/app/notion_client_login_transport.go index d4e11a8..139cfab 100644 --- a/internal/app/notion_client_login_transport.go +++ b/internal/app/notion_client_login_transport.go @@ -2,16 +2,14 @@ package app import ( "context" - "encoding/json" "fmt" "net/http" "net/url" - "os/exec" "strings" "time" ) -type loginWreqRequest struct { +type loginTransportRequest struct { Method string `json:"method"` URL string `json:"url"` Headers map[string]string `json:"headers"` @@ -22,7 +20,7 @@ type loginWreqRequest struct { RequestTimeoutMS int `json:"request_timeout_ms"` } -type loginWreqResponse struct { +type loginTransportResponse struct { Status int `json:"status"` ContentType string `json:"content_type"` Headers map[string]string `json:"headers"` @@ -32,47 +30,38 @@ type loginWreqResponse struct { type loginHTTPSession struct { *http.Client - ProxyResolver *ProxyResolver - AccountEmail string - Timeout time.Duration - Upstream NotionUpstream + ProxyResolver *ProxyResolver + AccountEmail string + Timeout time.Duration + Upstream NotionUpstream + UseSurfHelperTransport bool } -func runLoginHelperRequest(ctx context.Context, request loginWreqRequest) (*loginWreqResponse, error) { - if _, err := exec.LookPath("node"); err != nil { - return nil, &browserHelperUnavailableError{Message: "node not found"} - } - requestPayload, err := json.Marshal(request) - if err != 
nil { - return nil, err - } - stdoutBytes, err := executeHelperSubprocess(ctx, "node", ".cjs", nodeWreqLoginHelperScript(), requestPayload, browserHelperNodeEnv()) - if err != nil { - return nil, err - } - var response loginWreqResponse - if err := json.Unmarshal(stdoutBytes, &response); err != nil { - return nil, fmt.Errorf("login helper returned invalid json: %w", err) - } - return &response, nil -} +var ( + loginTransportRunSurfRequest = runLoginHelperRequestWithSurf + loginTransportRunFallbackRequest = runLoginHelperRequestWithSurf +) -func loginWreqDoRequest(ctx context.Context, session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) (int, http.Header, []byte, error) { +func loginTransportDoRequest(ctx context.Context, session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) (int, http.Header, []byte, error) { if session == nil { return 0, nil, nil, fmt.Errorf("login session is nil") } - request := buildLoginWreqRequest(session, method, targetURL, headers, body) - resp, err := runLoginHelperRequest(ctx, request) + request := buildLoginTransportRequest(session, method, targetURL, headers, body) + var ( + resp *loginTransportResponse + err error + ) + resp, err = loginTransportRunSurfRequest(ctx, request) if err != nil { return 0, nil, nil, err } if session.Client != nil { - applyLoginWreqSetCookies(session.Jar, targetURL, resp.SetCookies) + applyLoginTransportSetCookies(session.Jar, targetURL, resp.SetCookies) } - return resp.Status, loginWreqHTTPHeader(resp.Headers), []byte(resp.Body), nil + return resp.Status, loginTransportHTTPHeader(resp.Headers), []byte(resp.Body), nil } -func buildLoginWreqRequest(session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) loginWreqRequest { +func buildLoginTransportRequest(session *loginHTTPSession, method string, targetURL string, headers map[string]string, body []byte) loginTransportRequest { 
cookies := []ProbeCookie{} if session != nil && session.Client != nil { cookies = probeCookiesFromJar(session.Jar, targetURL) @@ -99,19 +88,19 @@ func buildLoginWreqRequest(session *loginHTTPSession, method string, targetURL s } cleanHeaders[k] = v } - return loginWreqRequest{ + return loginTransportRequest{ Method: strings.ToUpper(strings.TrimSpace(method)), URL: targetURL, Headers: cleanHeaders, Body: string(body), Cookies: cookies, - BrowserProfile: notionWreqDefaultBrowserProfile, + BrowserProfile: notionTransportDefaultBrowserProfile, Proxy: proxyValue, RequestTimeoutMS: timeoutMS, } } -func applyLoginWreqSetCookies(jar http.CookieJar, targetURL string, setCookies []ProbeCookie) { +func applyLoginTransportSetCookies(jar http.CookieJar, targetURL string, setCookies []ProbeCookie) { if jar == nil || len(setCookies) == 0 { return } @@ -132,7 +121,7 @@ func applyLoginWreqSetCookies(jar http.CookieJar, targetURL string, setCookies [ } } -func loginWreqHTTPHeader(headers map[string]string) http.Header { +func loginTransportHTTPHeader(headers map[string]string) http.Header { out := http.Header{} for k, v := range headers { out.Set(k, v) diff --git a/internal/app/notion_client_protocol_test.go b/internal/app/notion_client_protocol_test.go index 6ee2765..83f7926 100644 --- a/internal/app/notion_client_protocol_test.go +++ b/internal/app/notion_client_protocol_test.go @@ -3,11 +3,24 @@ package app import ( "context" "encoding/json" + "expvar" "net/http" "net/http/httptest" "testing" ) +func resetNotionTransportCacheForTest() { + notionTransportCache.mu.Lock() + defer notionTransportCache.mu.Unlock() + for _, transport := range notionTransportCache.items { + if transport != nil { + transport.CloseIdleConnections() + } + } + notionTransportCache.items = map[notionHTTPTransportCacheKey]*http.Transport{} + notionHTTPTransportCacheMetric.Init() +} + func newProtocolTestClient(cfg AppConfig) *NotionAIClient { cfg.APIKey = "test-api-key" if cfg.UpstreamBaseURL == "" { @@ 
-217,3 +230,146 @@ func TestPostJSONResponseAddsResinAccountHeaderWhenEnabled(t *testing.T) { t.Fatalf("%s = %q, want %q", defaultResinAccountHeader, got, want) } } + +func TestNewNotionAIClientWithModeReusesTransportForSameConfigAndAccount(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + first := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + second := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + streaming := newNotionAIClientWithMode(session, cfg, "alice@example.com", true) + + if first.HTTPClient == nil || second.HTTPClient == nil || streaming.HTTPClient == nil { + t.Fatalf("expected HTTP clients to be initialized") + } + if first.HTTPClient.Transport == nil || second.HTTPClient.Transport == nil || streaming.HTTPClient.Transport == nil { + t.Fatalf("expected transports to be initialized") + } + if first.HTTPClient.Transport != second.HTTPClient.Transport { + t.Fatalf("expected transport reuse for same account/config") + } + if first.HTTPClient.Transport != streaming.HTTPClient.Transport { + t.Fatalf("expected streaming and standard clients to share transport cache") + } + if first.HTTPClient.Timeout <= 0 { + t.Fatalf("expected non-streaming timeout to be configured") + } + if streaming.HTTPClient.Timeout != 0 { + t.Fatalf("expected streaming client timeout to be disabled, got %s", streaming.HTTPClient.Timeout) + } +} + +func TestNewNotionAIClientWithModeSeparatesTransportWhenProxyPolicyDiffers(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + cfg.Accounts = []NotionAccount{ + { + Email: "alice@example.com", + ProxyMode: proxyModeHTTP, + ProxyURL: "http://127.0.0.1:18080", + }, + { + Email: 
"bob@example.com", + ProxyMode: proxyModeHTTP, + ProxyURL: "http://127.0.0.1:28080", + }, + } + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + alice := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + bob := newNotionAIClientWithMode(session, cfg, "bob@example.com", false) + + if alice.HTTPClient == nil || bob.HTTPClient == nil { + t.Fatalf("expected HTTP clients to be initialized") + } + if alice.HTTPClient.Transport == nil || bob.HTTPClient.Transport == nil { + t.Fatalf("expected transports to be initialized") + } + if alice.HTTPClient.Transport == bob.HTTPClient.Transport { + t.Fatalf("expected separate transports when account proxy policy differs") + } +} + +func TestCachedNotionHTTPTransportRecordsCacheMetrics(t *testing.T) { + resetNotionTransportCacheForTest() + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", true) + + mustAtLeast := func(label string, wantMin int64) { + var got int64 + if v := notionHTTPTransportCacheMetric.Get(label); v != nil { + got = v.(*expvar.Int).Value() + } + if got < wantMin { + t.Fatalf("metric %s too small: got %d want >= %d", label, got, wantMin) + } + } + mustAtLeast("miss_new", 1) + mustAtLeast("hit_rlock", 1) +} + +func BenchmarkNewNotionAIClientWithModeTransportCache(b *testing.B) { + cfg := defaultConfig() + cfg.APIKey = "test-api-key" + session := SessionInfo{ + ClientVersion: "test-client-version", + UserID: "test-user", + SpaceID: "test-space", + 
Cookies: []ProbeCookie{{ + Name: "token_v2", + Value: "test-cookie", + }}, + } + + b.Run("warm_cache", func(b *testing.B) { + resetNotionTransportCacheForTest() + _ = newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + client := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + if client == nil || client.HTTPClient == nil || client.HTTPClient.Transport == nil { + b.Fatalf("expected client with transport") + } + } + }) + + b.Run("cold_cache_reset_each_iter", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + resetNotionTransportCacheForTest() + client := newNotionAIClientWithMode(session, cfg, "alice@example.com", false) + if client == nil || client.HTTPClient == nil || client.HTTPClient.Transport == nil { + b.Fatalf("expected client with transport") + } + } + }) +} diff --git a/internal/app/notion_client_surf_transport.go b/internal/app/notion_client_surf_transport.go new file mode 100644 index 0000000..96da400 --- /dev/null +++ b/internal/app/notion_client_surf_transport.go @@ -0,0 +1,207 @@ +package app + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/enetx/g" + "github.com/enetx/surf" +) + +func newSurfStdClient(proxy string) (*http.Client, error) { + builder := surf.NewClient().Builder().Session().Impersonate().Chrome() + if strings.TrimSpace(proxy) != "" { + builder = builder.Proxy(g.String(proxy)) + } + clientResult := builder.Build() + if err := clientResult.Err(); err != nil { + return nil, err + } + return clientResult.Unwrap().Std(), nil +} + +func loadProbeCookiesIntoJar(jar http.CookieJar, target *url.URL, cookies []ProbeCookie) { + if jar == nil || target == nil || len(cookies) == 0 { + return + } + items := make([]*http.Cookie, 0, len(cookies)) + for _, c := range cookies { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + items = append(items, 
&http.Cookie{ + Name: name, + Value: c.Value, + Path: "/", + }) + } + if len(items) > 0 { + jar.SetCookies(target, items) + } +} + +func runLoginHelperRequestWithSurf(ctx context.Context, request loginTransportRequest) (*loginTransportResponse, error) { + if ctx == nil { + ctx = context.Background() + } + stdClient, err := newSurfStdClient(request.Proxy) + if err != nil { + return nil, err + } + + timeout := time.Duration(request.RequestTimeoutMS) * time.Millisecond + if timeout < 30*time.Second { + timeout = 30 * time.Second + } + stdClient.Timeout = timeout + + parsedTargetURL, err := url.Parse(request.URL) + if err != nil { + return nil, err + } + loadProbeCookiesIntoJar(stdClient.Jar, parsedTargetURL, request.Cookies) + + var body io.Reader + if request.Body != "" { + body = bytes.NewBufferString(request.Body) + } + + method := strings.ToUpper(strings.TrimSpace(request.Method)) + if method == "" { + method = http.MethodGet + } + httpReq, err := http.NewRequestWithContext(ctx, method, parsedTargetURL.String(), body) + if err != nil { + return nil, err + } + for k, v := range request.Headers { + if strings.EqualFold(strings.TrimSpace(k), "cookie") { + continue + } + httpReq.Header.Set(k, v) + } + + resp, err := stdClient.Do(httpReq) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + out := &loginTransportResponse{ + Status: resp.StatusCode, + ContentType: resp.Header.Get("Content-Type"), + Headers: map[string]string{}, + Body: string(respBody), + SetCookies: []ProbeCookie{}, + } + for k, values := range resp.Header { + if strings.EqualFold(k, "set-cookie") || len(values) == 0 { + continue + } + out.Headers[strings.ToLower(k)] = values[len(values)-1] + } + + for _, c := range resp.Cookies() { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + out.SetCookies = append(out.SetCookies, ProbeCookie{ + Name: name, + Value: c.Value, + }) + } + + // 
Preserve effective cookies after redirects by reading from the jar. + if stdClient.Jar != nil { + jarCookies := stdClient.Jar.Cookies(parsedTargetURL) + if len(jarCookies) > 0 { + out.SetCookies = out.SetCookies[:0] + for _, c := range jarCookies { + name := strings.TrimSpace(c.Name) + if name == "" { + continue + } + out.SetCookies = append(out.SetCookies, ProbeCookie{ + Name: name, + Value: c.Value, + }) + } + } + } + + return out, nil +} + +func runInferenceTranscriptInBrowserWithSurf(ctx context.Context, client *NotionAIClient, payload map[string]any) (string, error) { + if ctx == nil { + ctx = context.Background() + } + request, err := buildBrowserTransportRequest(client, payload) + if err != nil { + return "", err + } + stdClient, err := newSurfStdClient(request.Proxy) + if err != nil { + return "", err + } + + timeout := time.Duration(request.RequestTimeoutMS) * time.Millisecond + if timeout <= 0 { + timeout = notionTransportDefaultRequestTimeout + } + stdClient.Timeout = timeout + + parsedRunURL, err := url.Parse(request.RunURL) + if err != nil { + return "", err + } + loadProbeCookiesIntoJar(stdClient.Jar, parsedRunURL, request.Cookies) + + requestBody, err := json.Marshal(request.Payload) + if err != nil { + return "", err + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, parsedRunURL.String(), bytes.NewReader(requestBody)) + if err != nil { + return "", err + } + for k, v := range request.Headers { + if strings.EqualFold(strings.TrimSpace(k), "cookie") { + continue + } + httpReq.Header.Set(k, v) + } + + resp, err := stdClient.Do(httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return "", fmt.Errorf("browser fallback returned non-success status=%d content_type=%q", resp.StatusCode, resp.Header.Get("Content-Type")) + } + text := string(respBody) + if err := 
detectInferenceStreamResponseFormat(text); err != nil { + return "", err + } + return text, nil +} diff --git a/internal/app/notion_client_surf_transport_test.go b/internal/app/notion_client_surf_transport_test.go new file mode 100644 index 0000000..026588b --- /dev/null +++ b/internal/app/notion_client_surf_transport_test.go @@ -0,0 +1,251 @@ +package app + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestRunLoginHelperRequestWithSurf_MapsStatusHeadersBodyAndSetCookies(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Method; got != http.MethodPost { + t.Fatalf("method = %s, want POST", got) + } + if got := r.Header.Get("X-Test"); got != "ok" { + t.Fatalf("X-Test header = %q, want ok", got) + } + http.SetCookie(w, &http.Cookie{Name: "token_v2", Value: "new-value", Path: "/"}) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + resp, err := runLoginHelperRequestWithSurf(context.Background(), loginTransportRequest{ + Method: http.MethodPost, + URL: server.URL, + Headers: map[string]string{"X-Test": "ok"}, + Body: `{"hello":"world"}`, + RequestTimeoutMS: 30000, + }) + if err != nil { + t.Fatalf("runLoginHelperRequestWithSurf error: %v", err) + } + if resp.Status != http.StatusCreated { + t.Fatalf("status = %d, want %d", resp.Status, http.StatusCreated) + } + if !strings.Contains(strings.ToLower(resp.ContentType), "application/json") { + t.Fatalf("content_type = %q", resp.ContentType) + } + if strings.TrimSpace(resp.Body) != `{"ok":true}` { + t.Fatalf("body = %q", resp.Body) + } + if len(resp.SetCookies) == 0 || resp.SetCookies[0].Name != "token_v2" { + t.Fatalf("set_cookies = %#v", resp.SetCookies) + } +} + +func TestRunLoginHelperRequestWithSurf_ContextCanceled(t *testing.T) { + 
ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := runLoginHelperRequestWithSurf(ctx, loginTransportRequest{ + Method: http.MethodGet, + URL: "https://example.com", + RequestTimeoutMS: 30000, + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } +} + +func TestRunLoginHelperRequestWithSurf_PreservesRedirectSetCookies(t *testing.T) { + const cookieName = "redirect_token" + const cookieValue = "set-on-redirect-hop" + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: cookieValue, Path: "/"}) + http.Redirect(w, r, server.URL+"/final", http.StatusFound) + }) + mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("done")) + }) + + resp, err := runLoginHelperRequestWithSurf(context.Background(), loginTransportRequest{ + Method: http.MethodGet, + URL: server.URL + "/start", + RequestTimeoutMS: 30000, + }) + if err != nil { + t.Fatalf("runLoginHelperRequestWithSurf error: %v", err) + } + if resp.Status != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.Status, http.StatusOK) + } + if got := probeCookieValue(resp.SetCookies, cookieName); got != cookieValue { + t.Fatalf("redirect cookie mismatch: got %q want %q, set_cookies=%#v", got, cookieValue, resp.SetCookies) + } +} + +func TestLoginTransportDoRequest_UsesSurfTransport(t *testing.T) { + origSurf := loginTransportRunSurfRequest + origFallback := loginTransportRunFallbackRequest + defer func() { + loginTransportRunSurfRequest = origSurf + loginTransportRunFallbackRequest = origFallback + }() + + surfHits := 0 + fallbackHits := 0 + loginTransportRunSurfRequest = func(_ context.Context, _ loginTransportRequest) (*loginTransportResponse, error) { + 
surfHits++ + return &loginTransportResponse{ + Status: http.StatusCreated, + Headers: map[string]string{"x-transport": "surf"}, + Body: "surf", + SetCookies: []ProbeCookie{{Name: "token_v2", Value: "surf"}}, + }, nil + } + loginTransportRunFallbackRequest = func(_ context.Context, _ loginTransportRequest) (*loginTransportResponse, error) { + fallbackHits++ + return &loginTransportResponse{ + Status: http.StatusAccepted, + Headers: map[string]string{"x-transport": "fallback"}, + Body: "fallback", + SetCookies: []ProbeCookie{{Name: "token_v2", Value: "fallback"}}, + }, nil + } + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("cookiejar.New error: %v", err) + } + session := &loginHTTPSession{ + Client: &http.Client{Jar: jar}, + UseSurfHelperTransport: true, + ProxyResolver: nil, + AccountEmail: "tester@example.com", + Timeout: 30 * time.Second, + Upstream: NotionUpstream{}, + } + + targetURL := "https://example.com/login" + status, headers, body, err := loginTransportDoRequest(context.Background(), session, http.MethodGet, targetURL, map[string]string{"X-Test": "1"}, nil) + if err != nil { + t.Fatalf("loginTransportDoRequest error: %v", err) + } + if status != http.StatusCreated { + t.Fatalf("status = %d, want %d", status, http.StatusCreated) + } + if got := headers.Get("x-transport"); got != "surf" { + t.Fatalf("x-transport = %q, want %q", got, "surf") + } + if got := string(body); got != "surf" { + t.Fatalf("body = %q, want %q", got, "surf") + } + if surfHits != 1 { + t.Fatalf("surf branch hits mismatch: got %d want 1", surfHits) + } + if fallbackHits != 0 { + t.Fatalf("fallback branch should stay unused, got hits=%d", fallbackHits) + } + if got := probeCookieValue(probeCookiesFromJar(session.Jar, targetURL), "token_v2"); got != "surf" { + t.Fatalf("session jar token_v2 = %q, want %q", got, "surf") + } +} + +func TestLoginTransportDoRequest_SurfPreservesRedirectCookiesInSessionJar(t *testing.T) { + const cookieName = "redirect_token" + const 
cookieValue = "persisted" + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.SetCookie(w, &http.Cookie{Name: cookieName, Value: cookieValue, Path: "/"}) + http.Redirect(w, r, server.URL+"/final", http.StatusFound) + }) + mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("cookiejar.New error: %v", err) + } + session := &loginHTTPSession{ + Client: &http.Client{Jar: jar}, + UseSurfHelperTransport: true, + Timeout: 30 * time.Second, + } + + targetURL := server.URL + "/start" + status, _, _, err := loginTransportDoRequest(context.Background(), session, http.MethodGet, targetURL, nil, nil) + if err != nil { + t.Fatalf("loginTransportDoRequest error: %v", err) + } + if status != http.StatusOK { + t.Fatalf("status = %d, want %d", status, http.StatusOK) + } + if got := probeCookieValue(probeCookiesFromJar(session.Jar, targetURL), cookieName); got != cookieValue { + t.Fatalf("session jar redirect cookie mismatch: got %q want %q", got, cookieValue) + } +} + +func TestRunInferenceTranscriptInBrowserWithSurf_ReturnsNDJSON(t *testing.T) { + line := `{"type":"agent-inference","id":"m1","finishedAt":"2026-05-03T00:00:00Z","value":[{"type":"text","content":"OK"}]}` + "\n" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Method; got != http.MethodPost { + t.Fatalf("method = %s, want POST", got) + } + if got := strings.TrimSpace(r.Header.Get("Cookie")); got == "" { + t.Fatalf("expected cookie header to be present") + } + var payload map[string]any + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode payload failed: %v", err) + } + if got := strings.TrimSpace(stringValue(payload["threadId"])); got != "t1" { + 
t.Fatalf("threadId = %q, want t1", got) + } + w.Header().Set("Content-Type", "application/x-ndjson") + _, _ = w.Write([]byte(line)) + })) + defer server.Close() + + client := newBrowserFallbackTestClient(server.URL) + body, err := runInferenceTranscriptInBrowserWithSurf(context.Background(), client, map[string]any{"threadId": "t1"}) + if err != nil { + t.Fatalf("runInferenceTranscriptInBrowserWithSurf error: %v", err) + } + if body != line { + t.Fatalf("body mismatch: got %q want %q", body, line) + } +} + +func TestRunInferenceTranscriptInBrowserWithSurf_RejectsHTMLChallenge(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + _, _ = w.Write([]byte("cloudflare cookiePart challenge")) + })) + defer server.Close() + + client := newBrowserFallbackTestClient(server.URL) + _, err := runInferenceTranscriptInBrowserWithSurf(context.Background(), client, map[string]any{"threadId": "t1"}) + if err == nil || !strings.Contains(err.Error(), "challenge/html content") { + t.Fatalf("unexpected err: %v", err) + } +} diff --git a/internal/app/notion_client_wreq_transport.go b/internal/app/notion_client_wreq_transport.go deleted file mode 100644 index 7a5884c..0000000 --- a/internal/app/notion_client_wreq_transport.go +++ /dev/null @@ -1,241 +0,0 @@ -package app - -func nodeWreqHelperScript() string { - return `const fs = require('fs'); -const { fetch } = require('node-wreq'); - -(async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - - const cookieMap = new Map(); - for (const item of input.cookies || []) { - const name = String((item && item.name) || '').trim(); - if (!name) continue; - cookieMap.set(name, String((item && item.value) || '')); - } - const cookieJar = { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - 
const pair = semi === -1 ? text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (name) cookieMap.set(name, value); - }, - }; - - const headers = {}; - for (const [key, value] of Object.entries(input.headers || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? '' : value); - } - - const fetchOptions = { - method: 'POST', - browser: input.browser_profile || 'chrome_142', - headers, - body: JSON.stringify(input.payload || {}), - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', text: '' }; - let response; - try { - response = await fetch(input.run_url, fetchOptions); - } catch (err) { - process.stderr.write((err && err.stack ? err.stack : String(err)) + '\n'); - process.exit(2); - return; - } - - result.status = response.status; - result.content_type = response.headers.get('content-type') || ''; - const isNDJSON = String(result.content_type).toLowerCase().includes('application/x-ndjson'); - if (!isNDJSON) { - result.text = await response.text(); - process.stdout.write(JSON.stringify(result)); - return; - } - - const idleAfterAnswerMs = Math.max(Number(input.idle_after_answer_ms || 0), 0); - const readable = response.wreq && typeof response.wreq.readable === 'function' - ? 
response.wreq.readable() - : null; - if (!readable) { - result.text = await response.text(); - process.stdout.write(JSON.stringify(result)); - return; - } - - let pending = ''; - let sawAnswer = false; - let sawTerminal = false; - let settled = false; - let idleTimer = null; - - const markLineState = (line) => { - if (!line) return; - try { - const parsed = JSON.parse(line); - if (String(parsed.type || '').toLowerCase() !== 'agent-inference' || !Array.isArray(parsed.value)) return; - const hasVisibleText = parsed.value.some((entry) => { - const t = String((entry && entry.type) || '').toLowerCase(); - const c = String((entry && entry.content) || ''); - return t === 'text' && c.trim() !== ''; - }); - if (!hasVisibleText) return; - sawAnswer = true; - if (parsed.finishedAt != null) sawTerminal = true; - } catch (_) {} - }; - - await new Promise((resolve, reject) => { - const settle = () => { - if (settled) return; - settled = true; - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - const remaining = pending.trim(); - if (remaining) markLineState(remaining); - try { readable.destroy(); } catch (_) {} - resolve(); - }; - const armIdle = () => { - if (idleTimer) { - clearTimeout(idleTimer); - idleTimer = null; - } - if (sawAnswer && idleAfterAnswerMs > 0) { - idleTimer = setTimeout(settle, idleAfterAnswerMs); - } - }; - readable.on('data', (chunk) => { - const text = Buffer.isBuffer(chunk) ? 
chunk.toString('utf8') : String(chunk); - result.text += text; - pending += text; - while (true) { - const newlineIndex = pending.indexOf('\n'); - if (newlineIndex === -1) break; - const line = pending.slice(0, newlineIndex).trim(); - pending = pending.slice(newlineIndex + 1); - markLineState(line); - if (sawTerminal) { - settle(); - return; - } - } - armIdle(); - }); - readable.on('end', settle); - readable.on('close', settle); - readable.on('error', (err) => { - if (settled) return; - settled = true; - if (idleTimer) clearTimeout(idleTimer); - reject(err); - }); - }); - - process.stdout.write(JSON.stringify(result)); -})().catch((error) => { - process.stderr.write((error && error.stack ? error.stack : String(error)) + '\n'); - process.exit(1); -}); -` -} - -func nodeWreqLoginHelperScript() string { - return `const fs = require('fs'); -const { fetch } = require('node-wreq'); - -(async () => { - const input = JSON.parse(fs.readFileSync(0, 'utf8')); - - const cookieMap = new Map(); - for (const item of input.cookies || []) { - const name = String((item && (item.name || item.Name)) || '').trim(); - if (!name) continue; - const rawValue = item && (item.value !== undefined ? item.value : item.Value); - cookieMap.set(name, String(rawValue == null ? '' : rawValue)); - } - const setCookieRecord = new Map(); - const cookieJar = { - getCookies() { - return [...cookieMap.entries()].map(([name, value]) => ({ name, value })); - }, - setCookie(cookie) { - const text = String(cookie || ''); - const semi = text.indexOf(';'); - const pair = semi === -1 ? 
text : text.slice(0, semi); - const eq = pair.indexOf('='); - if (eq <= 0) return; - const name = pair.slice(0, eq).trim(); - const value = pair.slice(eq + 1).trim(); - if (!name) return; - cookieMap.set(name, value); - setCookieRecord.set(name, value); - }, - }; - - const headers = {}; - for (const [key, value] of Object.entries(input.headers || {})) { - if (key === undefined || key === null) continue; - if (String(key).toLowerCase() === 'cookie') continue; - headers[String(key)] = String(value == null ? '' : value); - } - - const method = String(input.method || 'GET').toUpperCase(); - const fetchOptions = { - method, - browser: input.browser_profile || 'chrome_142', - headers, - cookieJar, - timeout: Math.max(Number(input.request_timeout_ms || 0), 30000), - throwHttpErrors: false, - }; - if (typeof input.body === 'string' && input.body.length > 0) { - fetchOptions.body = input.body; - } - const proxy = String(input.proxy || '').trim(); - if (proxy) fetchOptions.proxy = proxy; - - const result = { status: 0, content_type: '', headers: {}, body: '', set_cookies: [] }; - let response; - try { - response = await fetch(String(input.url || ''), fetchOptions); - } catch (err) { - process.stderr.write((err && err.stack ? err.stack : String(err)) + '\n'); - process.exit(2); - return; - } - - result.status = response.status; - if (response.headers && typeof response.headers.forEach === 'function') { - response.headers.forEach((value, key) => { - const lk = String(key).toLowerCase(); - if (lk === 'set-cookie') return; - result.headers[lk] = String(value); - }); - } - result.content_type = result.headers['content-type'] || ''; - result.body = await response.text(); - result.set_cookies = [...setCookieRecord.entries()].map(([name, value]) => ({ Name: name, Value: value })); - process.stdout.write(JSON.stringify(result)); -})().catch((error) => { - process.stderr.write((error && error.stack ? 
error.stack : String(error)) + '\n'); - process.exit(1); -}); -` -} diff --git a/internal/app/openai.go b/internal/app/openai.go index 27a3661..e2b5b4f 100644 --- a/internal/app/openai.go +++ b/internal/app/openai.go @@ -103,6 +103,13 @@ func normalizeChatInput(payload map[string]any) (NormalizedInput, error) { if !ok { return NormalizedInput{}, fmt.Errorf("messages must be an array") } + return normalizeChatInputFromParts(rawMessages, payload["attachments"]) +} + +func normalizeChatInputFromParts(rawMessages []any, attachmentsRaw any) (NormalizedInput, error) { + if rawMessages == nil { + return NormalizedInput{}, fmt.Errorf("messages must be an array") + } segments := make([]conversationPromptSegment, 0, len(rawMessages)) hiddenParts := make([]string, 0, len(rawMessages)) attachments := []InputAttachment{} @@ -125,7 +132,7 @@ func normalizeChatInput(payload map[string]any) (NormalizedInput, error) { hiddenParts = append(hiddenParts, hiddenSegments...) attachments = append(attachments, atts...) 
} - extra, err := extractAttachmentsFromAny(payload["attachments"]) + extra, err := extractAttachmentsFromAny(attachmentsRaw) if err != nil { return NormalizedInput{}, err } @@ -221,6 +228,10 @@ func buildConversationTranscriptPrompt(segments []conversationPromptSegment) str } func normalizeResponsesInput(payload map[string]any, previousResponse map[string]any) (NormalizedInput, error) { + return normalizeResponsesInputFromParts(payload["input"], payload["attachments"], previousResponse) +} + +func normalizeResponsesInputFromParts(rawInput any, attachmentsRaw any, previousResponse map[string]any) (NormalizedInput, error) { var ( prompt string hiddenPrompt string @@ -228,7 +239,7 @@ func normalizeResponsesInput(payload map[string]any, previousResponse map[string segments []conversationPromptSegment err error ) - switch x := payload["input"].(type) { + switch x := rawInput.(type) { case string: prompt = strings.TrimSpace(x) segments = appendConversationPromptSegment(segments, "user", prompt) @@ -242,10 +253,10 @@ func normalizeResponsesInput(payload map[string]any, previousResponse map[string return NormalizedInput{}, err } default: - prompt = strings.TrimSpace(flattenContent(payload["input"])) + prompt = strings.TrimSpace(flattenContent(rawInput)) segments = appendConversationPromptSegment(segments, "user", prompt) } - extra, err := extractAttachmentsFromAny(payload["attachments"]) + extra, err := extractAttachmentsFromAny(attachmentsRaw) if err != nil { return NormalizedInput{}, err } diff --git a/internal/app/openai_types.go b/internal/app/openai_types.go new file mode 100644 index 0000000..c0db59a --- /dev/null +++ b/internal/app/openai_types.go @@ -0,0 +1,339 @@ +package app + +import ( + "encoding/json" + "net/http" + "strings" +) + +type chatCompletionsRequestBody struct { + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Conversation string 
`json:"conversation,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + Thread string `json:"thread,omitempty"` + NotionThreadID string `json:"notion_thread_id,omitempty"` + AccountEmail string `json:"account_email,omitempty"` + NotionAccountEmail string `json:"notion_account_email,omitempty"` + UseWebSearch *bool `json:"use_web_search,omitempty"` + Metadata any `json:"metadata,omitempty"` + Tools any `json:"tools,omitempty"` + StreamOptions any `json:"stream_options,omitempty"` + Messages any `json:"messages,omitempty"` + Attachments any `json:"attachments,omitempty"` + StreamIncludeUsage *bool `json:"-"` + Type string `json:"type,omitempty"` + UserName string `json:"user_name,omitempty"` + CharName string `json:"char_name,omitempty"` + GroupNames []string `json:"group_names,omitempty"` + ContinuePrefill string `json:"continue_prefill,omitempty"` + ShowThoughts *bool `json:"show_thoughts,omitempty"` +} + +type responsesRequestBody struct { + Model string `json:"model,omitempty"` + Stream bool `json:"stream,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Conversation string `json:"conversation,omitempty"` + ThreadID string `json:"thread_id,omitempty"` + Thread string `json:"thread,omitempty"` + NotionThreadID string `json:"notion_thread_id,omitempty"` + AccountEmail string `json:"account_email,omitempty"` + NotionAccountEmail string `json:"notion_account_email,omitempty"` + UseWebSearch *bool `json:"use_web_search,omitempty"` + Metadata any `json:"metadata,omitempty"` + Tools any `json:"tools,omitempty"` + Input any `json:"input,omitempty"` + Attachments any `json:"attachments,omitempty"` +} + +func trimStringSlice(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + if clean := strings.TrimSpace(value); clean != "" { + out = append(out, clean) + } + } + if len(out) == 0 { + 
return nil + } + return out +} + +func normalizeTypedChatCompletionsRequestBody(body chatCompletionsRequestBody) chatCompletionsRequestBody { + body.Model = strings.TrimSpace(body.Model) + body.ConversationID = strings.TrimSpace(body.ConversationID) + body.Conversation = strings.TrimSpace(body.Conversation) + body.ThreadID = strings.TrimSpace(body.ThreadID) + body.Thread = strings.TrimSpace(body.Thread) + body.NotionThreadID = strings.TrimSpace(body.NotionThreadID) + body.AccountEmail = strings.TrimSpace(body.AccountEmail) + body.NotionAccountEmail = strings.TrimSpace(body.NotionAccountEmail) + body.Type = strings.TrimSpace(body.Type) + body.UserName = strings.TrimSpace(body.UserName) + body.CharName = strings.TrimSpace(body.CharName) + body.ContinuePrefill = strings.TrimSpace(body.ContinuePrefill) + body.GroupNames = trimStringSlice(body.GroupNames) + if value, ok := parseIncludeUsageFromStreamOptionsAny(body.StreamOptions); ok { + copyValue := value + body.StreamIncludeUsage = &copyValue + } + return body +} + +func normalizeTypedResponsesRequestBody(body responsesRequestBody) responsesRequestBody { + body.Model = strings.TrimSpace(body.Model) + body.PreviousResponseID = strings.TrimSpace(body.PreviousResponseID) + body.ConversationID = strings.TrimSpace(body.ConversationID) + body.Conversation = strings.TrimSpace(body.Conversation) + body.ThreadID = strings.TrimSpace(body.ThreadID) + body.Thread = strings.TrimSpace(body.Thread) + body.NotionThreadID = strings.TrimSpace(body.NotionThreadID) + body.AccountEmail = strings.TrimSpace(body.AccountEmail) + body.NotionAccountEmail = strings.TrimSpace(body.NotionAccountEmail) + return body +} + +func requestedModelFromTyped(model string, fallback string) string { + modelID := strings.TrimSpace(model) + if modelID == "" { + return fallback + } + return modelID +} + +func extractChatCompletionsRequestBody(payload map[string]any) chatCompletionsRequestBody { + if payload == nil { + return chatCompletionsRequestBody{} + } + body 
:= chatCompletionsRequestBody{ + Model: strings.TrimSpace(stringValue(payload["model"])), + ConversationID: strings.TrimSpace(stringValue(payload["conversation_id"])), + Conversation: strings.TrimSpace(stringValue(payload["conversation"])), + ThreadID: strings.TrimSpace(stringValue(payload["thread_id"])), + Thread: strings.TrimSpace(stringValue(payload["thread"])), + NotionThreadID: strings.TrimSpace(stringValue(payload["notion_thread_id"])), + AccountEmail: strings.TrimSpace(stringValue(payload["account_email"])), + NotionAccountEmail: strings.TrimSpace(stringValue(payload["notion_account_email"])), + Type: strings.TrimSpace(stringValue(payload["type"])), + UserName: strings.TrimSpace(stringValue(payload["user_name"])), + CharName: strings.TrimSpace(stringValue(payload["char_name"])), + ContinuePrefill: strings.TrimSpace(stringValue(payload["continue_prefill"])), + GroupNames: stringSliceValue(payload["group_names"]), + } + body.Stream, _ = payload["stream"].(bool) + if value, ok := parseBoolField(payload["use_web_search"]); ok { + copyValue := value + body.UseWebSearch = &copyValue + } + if value, ok := parseBoolField(payload["show_thoughts"]); ok { + copyValue := value + body.ShowThoughts = &copyValue + } + body.Metadata = payload["metadata"] + body.Tools = payload["tools"] + body.StreamOptions = payload["stream_options"] + body.Messages = payload["messages"] + body.Attachments = payload["attachments"] + if value, ok := parseIncludeUsageFromStreamOptionsAny(body.StreamOptions); ok { + copyValue := value + body.StreamIncludeUsage = &copyValue + } + return normalizeTypedChatCompletionsRequestBody(body) +} + +func extractResponsesRequestBody(payload map[string]any) responsesRequestBody { + if payload == nil { + return responsesRequestBody{} + } + body := responsesRequestBody{ + Model: strings.TrimSpace(stringValue(payload["model"])), + PreviousResponseID: strings.TrimSpace(stringValue(payload["previous_response_id"])), + ConversationID: 
strings.TrimSpace(stringValue(payload["conversation_id"])), + Conversation: strings.TrimSpace(stringValue(payload["conversation"])), + ThreadID: strings.TrimSpace(stringValue(payload["thread_id"])), + Thread: strings.TrimSpace(stringValue(payload["thread"])), + NotionThreadID: strings.TrimSpace(stringValue(payload["notion_thread_id"])), + AccountEmail: strings.TrimSpace(stringValue(payload["account_email"])), + NotionAccountEmail: strings.TrimSpace(stringValue(payload["notion_account_email"])), + } + body.Stream, _ = payload["stream"].(bool) + if value, ok := parseBoolField(payload["use_web_search"]); ok { + copyValue := value + body.UseWebSearch = &copyValue + } + body.Metadata = payload["metadata"] + body.Tools = payload["tools"] + body.Input = payload["input"] + body.Attachments = payload["attachments"] + return normalizeTypedResponsesRequestBody(body) +} + +func requestedConversationIDFromTyped(r *http.Request, conversationID string, conversation string, metadata any) string { + if fromHeader := firstRequestValue(r, "X-Conversation-ID", "X-Notion-Conversation-ID"); fromHeader != "" { + return fromHeader + } + if value := strings.TrimSpace(conversationID); value != "" { + return value + } + if value := strings.TrimSpace(conversation); value != "" { + return value + } + return parseStringFieldFromMetadataAny(metadata, "conversation_id", "notion_conversation_id") +} + +func requestedThreadIDFromTyped(r *http.Request, threadID string, thread string, notionThreadID string, metadata any) string { + if fromHeader := firstRequestValue(r, "X-Thread-ID", "X-Notion-Thread-ID"); fromHeader != "" { + return fromHeader + } + for _, value := range []string{threadID, thread, notionThreadID} { + if clean := strings.TrimSpace(value); clean != "" { + return clean + } + } + return parseStringFieldFromMetadataAny(metadata, "thread_id", "notion_thread_id") +} + +func requestedAccountEmailFromTyped(r *http.Request, accountEmail string, notionAccountEmail string, metadata any) string { + 
if fromHeader := firstRequestValue(r, "X-Account-Email", "X-Notion-Account-Email"); fromHeader != "" { + return fromHeader + } + for _, value := range []string{accountEmail, notionAccountEmail} { + if clean := strings.TrimSpace(value); clean != "" { + return clean + } + } + return parseStringFieldFromMetadataAny(metadata, "account_email", "notion_account_email") +} + +func requestedWebSearchFromTyped(useWebSearch *bool, metadata any, tools any, fallback bool) bool { + if useWebSearch != nil { + return *useWebSearch + } + if value, ok := parseWebSearchFromMetadataAny(metadata); ok { + return value + } + if value, ok := parseWebSearchFromToolsAny(tools); ok { + return value + } + return fallback +} + +func parseWebSearchFromMetadataAny(raw any) (bool, bool) { + meta := decodeJSONObjectAny(raw) + if meta == nil { + return false, false + } + for _, key := range []string{"use_web_search", "notion_use_web_search"} { + if value, ok := meta[key]; ok { + if parsed, parsedOK := parseBoolField(value); parsedOK { + return parsed, true + } + } + } + return false, false +} + +func parseStringFieldFromMetadataAny(raw any, keys ...string) string { + meta := decodeJSONObjectAny(raw) + if meta == nil { + return "" + } + for _, key := range keys { + if value := strings.TrimSpace(stringValue(meta[key])); value != "" { + return value + } + } + return "" +} + +func decodeJSONObjectAny(raw any) map[string]any { + if raw == nil { + return nil + } + if meta := mapValue(raw); meta != nil { + return meta + } + var decoded map[string]any + switch value := raw.(type) { + case json.RawMessage: + if err := json.Unmarshal(value, &decoded); err == nil { + return decoded + } + case []byte: + if err := json.Unmarshal(value, &decoded); err == nil { + return decoded + } + case string: + if err := json.Unmarshal([]byte(value), &decoded); err == nil { + return decoded + } + } + return nil +} + +func parseWebSearchFromToolsAny(raw any) (bool, bool) { + if raw == nil { + return false, false + } + 
toolItems := sliceValue(raw) + if len(toolItems) == 0 { + switch value := raw.(type) { + case json.RawMessage: + var decoded []map[string]any + if err := json.Unmarshal(value, &decoded); err == nil { + toolItems = sliceValue(decoded) + } + case []byte: + var decoded []map[string]any + if err := json.Unmarshal(value, &decoded); err == nil { + toolItems = sliceValue(decoded) + } + case string: + var decoded []map[string]any + if err := json.Unmarshal([]byte(value), &decoded); err == nil { + toolItems = sliceValue(decoded) + } + } + } + for _, item := range toolItems { + tool := mapValue(item) + if tool == nil { + continue + } + toolType := strings.TrimSpace(stringValue(tool["type"])) + if strings.Contains(toolType, "web_search") { + return true, true + } + } + return false, false +} + +func parseIncludeUsageFromStreamOptionsAny(raw any) (bool, bool) { + options := decodeJSONObjectAny(raw) + if options == nil { + return false, false + } + return parseBoolField(options["include_usage"]) +} + +func (body chatCompletionsRequestBody) likelySillyTavernByEnvelope() bool { + if strings.TrimSpace(body.Type) != "" { + return true + } + if strings.TrimSpace(body.UserName) != "" && strings.TrimSpace(body.CharName) != "" { + return true + } + if len(body.GroupNames) > 0 { + return true + } + if strings.TrimSpace(body.ContinuePrefill) != "" { + return true + } + return body.ShowThoughts != nil +} diff --git a/internal/app/prompt_guard.go b/internal/app/prompt_guard.go index 0f8ea74..88d0576 100644 --- a/internal/app/prompt_guard.go +++ b/internal/app/prompt_guard.go @@ -96,6 +96,23 @@ func promptGuardProfileChain(cfg AppConfig, hasTools bool) []promptProfile { return chain[:limit] } +func buildPromptGuardAllRetryPrefixes(promptCfg PromptConfig) []string { + totalCap := len(defaultPromptCodingRetryPrefixes()) + + len(defaultPromptGeneralRetryPrefixes()) + + len(defaultPromptDirectAnswerRetryPrefixes()) + + len(promptCfg.CodingRetryPrefixes) + + len(promptCfg.GeneralRetryPrefixes) + 
+ len(promptCfg.DirectAnswerRetryPrefixes) + out := make([]string, 0, totalCap) + out = append(out, defaultPromptCodingRetryPrefixes()...) + out = append(out, defaultPromptGeneralRetryPrefixes()...) + out = append(out, defaultPromptDirectAnswerRetryPrefixes()...) + out = append(out, promptCfg.CodingRetryPrefixes...) + out = append(out, promptCfg.GeneralRetryPrefixes...) + out = append(out, promptCfg.DirectAnswerRetryPrefixes...) + return out +} + func resolvePromptGuardProfile(cfg AppConfig, request PromptRunRequest) (promptProfile, int) { chain := promptGuardProfileChain(cfg, false) if len(chain) == 0 { @@ -175,11 +192,10 @@ func promptGuardPrepareRequest(cfg AppConfig, request PromptRunRequest) PromptRu func promptGuardStripRetryPrefixes(cfg AppConfig, text string) string { current := text - all := append(append([]string{}, defaultPromptCodingRetryPrefixes()...), defaultPromptGeneralRetryPrefixes()...) - all = append(all, defaultPromptDirectAnswerRetryPrefixes()...) - all = append(all, cfg.Prompt.CodingRetryPrefixes...) - all = append(all, cfg.Prompt.GeneralRetryPrefixes...) - all = append(all, cfg.Prompt.DirectAnswerRetryPrefixes...) 
+ all := cfg.Prompt.precomputedAllRetryPrefixes + if len(all) == 0 { + all = buildPromptGuardAllRetryPrefixes(cfg.Prompt) + } matched := true for matched { matched = false @@ -193,13 +209,14 @@ func promptGuardStripRetryPrefixes(cfg AppConfig, text string) string { return current } +var promptGuardCodingRequestPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)\b(code|coding|program|function|class|bug|debug|refactor|api|sdk|javascript|typescript|python|golang|rust|docker|sql|bash|shell|json|yaml|repository|repo|frontend|backend|server|client)\b`), + regexp.MustCompile(`代码|编程|开发|函数|脚本|调试|报错|异常|接口|部署|构建|数据库|仓库|前端|后端|服务端|客户端|测试|日志`), + regexp.MustCompile("```"), +} + func promptGuardLooksLikeCodingRequest(text string) bool { - patterns := []*regexp.Regexp{ - regexp.MustCompile(`(?i)\b(code|coding|program|function|class|bug|debug|refactor|api|sdk|javascript|typescript|python|golang|rust|docker|sql|bash|shell|json|yaml|repository|repo|frontend|backend|server|client)\b`), - regexp.MustCompile(`代码|编程|开发|函数|脚本|调试|报错|异常|接口|部署|构建|数据库|仓库|前端|后端|服务端|客户端|测试|日志`), - regexp.MustCompile("```"), - } - for _, pattern := range patterns { + for _, pattern := range promptGuardCodingRequestPatterns { if pattern.MatchString(text) { return true } diff --git a/internal/app/request_dispatch.go b/internal/app/request_dispatch.go index 68e1cb3..b8f152f 100644 --- a/internal/app/request_dispatch.go +++ b/internal/app/request_dispatch.go @@ -3,9 +3,11 @@ package app import ( "context" "errors" + "expvar" "fmt" "net/http" "strings" + "sync" "time" ) @@ -16,6 +18,88 @@ const ( var errDispatchCapacityExceeded = errors.New("dispatch capacity exceeded") +var transportClientNewTotalMetric = expvar.NewMap("notion2api_transport_client_new_total") + +type probeCacheEntry struct { + lastChecked time.Time + lastOK bool +} + +type probeCache struct { + mu sync.Mutex + entries map[string]probeCacheEntry +} + +func newProbeCache() *probeCache { + return &probeCache{ + entries: 
map[string]probeCacheEntry{}, + } +} + +func (c *probeCache) shouldProbe(accountKey string, ttl time.Duration, now time.Time) bool { + if c == nil { + return true + } + if strings.TrimSpace(accountKey) == "" { + return true + } + if ttl <= 0 { + return true + } + c.mu.Lock() + defer c.mu.Unlock() + if c.entries == nil { + c.entries = map[string]probeCacheEntry{} + return true + } + entry, ok := c.entries[accountKey] + if !ok { + return true + } + if !entry.lastOK { + return true + } + return now.Sub(entry.lastChecked) >= ttl +} + +func (c *probeCache) markSuccess(accountKey string, now time.Time) { + if c == nil { + return + } + accountKey = strings.TrimSpace(accountKey) + if accountKey == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + if c.entries == nil { + c.entries = map[string]probeCacheEntry{} + } + c.entries[accountKey] = probeCacheEntry{lastChecked: now, lastOK: true} +} + +func (c *probeCache) markFailure(accountKey string) { + if c == nil { + return + } + accountKey = strings.TrimSpace(accountKey) + if accountKey == "" { + return + } + c.mu.Lock() + defer c.mu.Unlock() + delete(c.entries, accountKey) +} + +func (c *probeCache) invalidateAll() { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + c.entries = map[string]probeCacheEntry{} +} + func requestTimeout(cfg AppConfig) time.Duration { return time.Duration(maxInt(cfg.TimeoutSec, 10)) * time.Second } @@ -40,7 +124,7 @@ func mergeDispatchCandidates(preferred *NotionAccount, candidates []NotionAccoun out := make([]NotionAccount, 0, len(candidates)+1) seen := map[string]struct{}{} appendCandidate := func(account NotionAccount) { - key := canonicalEmailKey(account.Email) + key := getAccountEmailKey(account) if key == "" { return } @@ -60,8 +144,19 @@ func mergeDispatchCandidates(preferred *NotionAccount, candidates []NotionAccoun } func resolveDispatchCandidates(cfg AppConfig, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { + poolCandidates := 
buildDispatchCandidateOrder(cfg, now) + return resolveDispatchCandidatesWithPool(cfg, poolCandidates, request, now) +} + +func resolveDispatchCandidatesFromSnapshot(bundle *snapshotBundle, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { + if bundle == nil { + return nil, noEligibleAccountsError() + } + return resolveDispatchCandidatesWithPool(bundle.Config, pickDispatchCandidatesFromSnapshot(bundle, now), request, now) +} + +func resolveDispatchCandidatesWithPool(cfg AppConfig, poolCandidates []NotionAccount, request PromptRunRequest, now time.Time) ([]NotionAccount, error) { pinnedEmail := strings.TrimSpace(request.PinnedAccountEmail) - poolCandidates := pickDispatchCandidates(cfg, now) if pinnedEmail == "" { if len(poolCandidates) == 0 { return nil, noEligibleAccountsError() @@ -116,6 +211,14 @@ func dispatchProtocolProbeTimeout(cfg AppConfig) time.Duration { return time.Duration(seconds) * time.Second } +func dispatchProbeCacheTTL(cfg AppConfig) time.Duration { + seconds := cfg.Dispatch.ProbeCacheTTLSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + func isDispatchContextAbort(ctx context.Context, err error) bool { if err == nil { return false @@ -126,17 +229,82 @@ func isDispatchContextAbort(ctx context.Context, err error) bool { return ctx != nil && ctx.Err() != nil } -func (a *App) probeAccountProtocolHealth(ctx context.Context, cfg AppConfig, session SessionInfo) error { +func (a *App) shouldProbeAccountProtocolHealth(accountKey string, ttl time.Duration, now time.Time) bool { + if a == nil { + return true + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return true + } + return a.State.DispatchProbeCache.shouldProbe(accountKey, ttl, now) +} + +func (a *App) markAccountProtocolProbeSuccess(accountKey string, now time.Time) { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.markSuccess(accountKey, now) 
+} + +func (a *App) markAccountProtocolProbeFailure(accountKey string) { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.markFailure(accountKey) +} + +func (a *App) invalidateDispatchProbeCache() { + if a == nil { + return + } + if a.State == nil || a.State.DispatchProbeCache == nil { + return + } + a.State.DispatchProbeCache.invalidateAll() +} + +func (a *App) probeAccountProtocolHealth(ctx context.Context, cfg AppConfig, session SessionInfo, accountEmail string) error { + accountKey := canonicalEmailKey(accountEmail) + if accountKey == "" { + accountKey = canonicalEmailKey(session.UserEmail) + } + now := time.Now() + ttl := dispatchProbeCacheTTL(cfg) + if !a.shouldProbeAccountProtocolHealth(accountKey, ttl, now) { + return nil + } if a.accountProtocolProbeOverride != nil { - return a.accountProtocolProbeOverride(ctx, cfg, session) + err := a.accountProtocolProbeOverride(ctx, cfg, session) + if err == nil { + a.markAccountProtocolProbeSuccess(accountKey, now) + return nil + } + if isDispatchContextAbort(ctx, err) { + a.markAccountProtocolProbeSuccess(accountKey, now) + return nil + } + a.markAccountProtocolProbeFailure(accountKey) + return err } probeCtx, cancel := context.WithTimeout(ctx, dispatchProtocolProbeTimeout(cfg)) defer cancel() client := newNotionAIClient(session, cfg, "") _, err := client.listInferenceTranscripts(probeCtx) if isDispatchContextAbort(probeCtx, err) { + a.markAccountProtocolProbeSuccess(accountKey, now) return nil } + if err != nil { + a.markAccountProtocolProbeFailure(accountKey) + return err + } + a.markAccountProtocolProbeSuccess(accountKey, now) return err } @@ -145,7 +313,7 @@ func (a *App) loadReadyDispatchSession(ctx context.Context, cfg AppConfig, accou if err != nil { return SessionInfo{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, account.Email); err != 
nil { return SessionInfo{}, err } return session, nil @@ -183,7 +351,7 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, if err != nil { return InferenceResult{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { return InferenceResult{}, err } @@ -204,9 +372,10 @@ func (a *App) runPromptActiveFallback(r *http.Request, request PromptRunRequest, } if cfg.ResolveSessionRefresh().RetryOnAuthError && isSessionRetryableError(err) && !emittedAny { if refreshErr := a.State.RefreshSession(ctx, "prompt_retry_fallback"); refreshErr == nil { + a.invalidateDispatchProbeCache() _, refreshed, _ := a.State.Snapshot() if strings.TrimSpace(refreshed.UserID) != "" && strings.TrimSpace(refreshed.SpaceID) != "" && len(refreshed.Cookies) > 0 { - if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { + if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed, ""); probeErr != nil { return InferenceResult{}, probeErr } return a.runPromptWithSession(ctx, cfg, refreshed, "", request, wrappedDelta) @@ -226,7 +395,7 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun if err != nil { return InferenceResult{}, err } - if err := a.probeAccountProtocolHealth(ctx, cfg, session); err != nil { + if err := a.probeAccountProtocolHealth(ctx, cfg, session, ""); err != nil { return InferenceResult{}, err } @@ -261,9 +430,10 @@ func (a *App) runPromptActiveFallbackWithSink(r *http.Request, request PromptRun } if cfg.ResolveSessionRefresh().RetryOnAuthError && isSessionRetryableError(err) && !emittedAny { if refreshErr := a.State.RefreshSession(ctx, "prompt_retry_fallback"); refreshErr == nil { + a.invalidateDispatchProbeCache() _, refreshed, _ := a.State.Snapshot() if strings.TrimSpace(refreshed.UserID) != "" && strings.TrimSpace(refreshed.SpaceID) != "" && len(refreshed.Cookies) > 0 { - if 
probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed); probeErr != nil { + if probeErr := a.probeAccountProtocolHealth(ctx, cfg, refreshed, ""); probeErr != nil { return InferenceResult{}, probeErr } return a.runPromptWithSessionWithSink(ctx, cfg, refreshed, "", request, InferenceStreamSink{ @@ -292,7 +462,17 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest defer cancel() now := time.Now() - candidates, err := resolveDispatchCandidates(cfg, request, now) + var candidates []NotionAccount + var err error + if a != nil && a.State != nil { + if snap := a.State.snap.Load(); snap != nil { + candidates, err = resolveDispatchCandidatesFromSnapshot(snap, request, now) + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } if err != nil { return InferenceResult{}, err } @@ -359,6 +539,7 @@ func (a *App) runPromptWithAccountPool(r *http.Request, request PromptRunRequest refreshedCfg, refreshErr := a.State.tryRefreshAccount(ctx, cfg, account) if refreshErr == nil { if saveErr := a.State.SaveAndApply(refreshedCfg); saveErr == nil { + a.invalidateDispatchProbeCache() cfg = refreshedCfg refreshedAccount, _, ok := cfg.FindAccount(account.Email) if ok { @@ -445,7 +626,17 @@ func (a *App) runPromptWithAccountPoolWithSink(r *http.Request, request PromptRu defer cancel() now := time.Now() - candidates, err := resolveDispatchCandidates(cfg, request, now) + var candidates []NotionAccount + var err error + if a != nil && a.State != nil { + if snap := a.State.snap.Load(); snap != nil { + candidates, err = resolveDispatchCandidatesFromSnapshot(snap, request, now) + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } + } else { + candidates, err = resolveDispatchCandidates(cfg, request, now) + } if err != nil { return InferenceResult{}, err } @@ -526,6 +717,7 @@ func (a *App) runPromptWithAccountPoolWithSink(r 
*http.Request, request PromptRu refreshedCfg, refreshErr := a.State.tryRefreshAccount(ctx, cfg, account) if refreshErr == nil { if saveErr := a.State.SaveAndApply(refreshedCfg); saveErr == nil { + a.invalidateDispatchProbeCache() cfg = refreshedCfg if refreshedAccount, _, ok := cfg.FindAccount(account.Email); ok { refreshedSession, loadErr := a.loadReadyDispatchSession(ctx, cfg, refreshedAccount) diff --git a/internal/app/response_store.go b/internal/app/response_store.go new file mode 100644 index 0000000..0209cd7 --- /dev/null +++ b/internal/app/response_store.go @@ -0,0 +1,217 @@ +package app + +import ( + "container/heap" + "strings" + "time" +) + +const responseStoreCleanupInterval = 30 * time.Second + +type responseExpiryEntry struct { + responseID string + createdAt time.Time +} + +type responseExpiryHeap []responseExpiryEntry + +func (h responseExpiryHeap) Len() int { + return len(h) +} + +func (h responseExpiryHeap) Less(i, j int) bool { + return h[i].createdAt.Before(h[j].createdAt) +} + +func (h responseExpiryHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *responseExpiryHeap) Push(x any) { + entry, _ := x.(responseExpiryEntry) + *h = append(*h, entry) +} + +func (h *responseExpiryHeap) Pop() any { + if h == nil || len(*h) == 0 { + return responseExpiryEntry{} + } + old := *h + last := old[len(old)-1] + *h = old[:len(old)-1] + return last +} + +type responseStore struct { + ttl time.Duration + items map[string]StoredResponse + expirations responseExpiryHeap +} + +var testHookResponseStorePrunePop func() + +func normalizeResponseStoreTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Second + } + return ttl +} + +func newResponseStore(ttl time.Duration) *responseStore { + store := &responseStore{ + ttl: normalizeResponseStoreTTL(ttl), + items: map[string]StoredResponse{}, + expirations: responseExpiryHeap{}, + } + heap.Init(&store.expirations) + return store +} + +func (s *responseStore) setTTL(ttl time.Duration) { + if s 
== nil { + return + } + s.ttl = normalizeResponseStoreTTL(ttl) +} + +func (s *responseStore) ensureInitialized() { + if s == nil { + return + } + if s.items == nil { + s.items = map[string]StoredResponse{} + } + if s.expirations == nil { + s.expirations = responseExpiryHeap{} + heap.Init(&s.expirations) + } +} + +func (s *responseStore) save(responseID string, record StoredResponse, now time.Time) { + if s == nil { + return + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return + } + s.ensureInitialized() + now = now.UTC() + s.pruneExpired(now) + + createdAt := record.CreatedAt.UTC() + if createdAt.IsZero() { + createdAt = now + } + record.CreatedAt = createdAt + record.ConversationID = strings.TrimSpace(record.ConversationID) + record.ThreadID = strings.TrimSpace(record.ThreadID) + record.AccountEmail = strings.TrimSpace(record.AccountEmail) + + s.items[responseID] = record + heap.Push(&s.expirations, responseExpiryEntry{ + responseID: responseID, + createdAt: createdAt, + }) +} + +func (s *responseStore) get(responseID string, now time.Time) (StoredResponse, bool) { + if s == nil { + return StoredResponse{}, false + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return StoredResponse{}, false + } + s.ensureInitialized() + now = now.UTC() + s.pruneExpired(now) + record, ok := s.items[responseID] + if !ok { + return StoredResponse{}, false + } + if now.Sub(record.CreatedAt) > s.ttl { + delete(s.items, responseID) + return StoredResponse{}, false + } + return record, true +} + +func (s *responseStore) replaceAll(records map[string]StoredResponse) { + if s == nil { + return + } + s.ensureInitialized() + s.items = map[string]StoredResponse{} + s.expirations = responseExpiryHeap{} + heap.Init(&s.expirations) + for responseID, record := range records { + cleanID := strings.TrimSpace(responseID) + if cleanID == "" { + continue + } + createdAt := record.CreatedAt.UTC() + record.CreatedAt = createdAt + 
record.ConversationID = strings.TrimSpace(record.ConversationID) + record.ThreadID = strings.TrimSpace(record.ThreadID) + record.AccountEmail = strings.TrimSpace(record.AccountEmail) + s.items[cleanID] = record + heap.Push(&s.expirations, responseExpiryEntry{ + responseID: cleanID, + createdAt: createdAt, + }) + } +} + +func (s *responseStore) pruneExpired(now time.Time) int { + if s == nil { + return 0 + } + s.ensureInitialized() + if len(s.items) == 0 || len(s.expirations) == 0 { + return 0 + } + now = now.UTC() + removed := 0 + for len(s.expirations) > 0 { + top := s.expirations[0] + if now.Sub(top.createdAt) <= s.ttl { + break + } + entry, _ := heap.Pop(&s.expirations).(responseExpiryEntry) + if testHookResponseStorePrunePop != nil { + testHookResponseStorePrunePop() + } + current, ok := s.items[entry.responseID] + if !ok { + continue + } + if !current.CreatedAt.UTC().Equal(entry.createdAt) { + continue + } + delete(s.items, entry.responseID) + removed++ + } + return removed +} + +func (s *responseStore) deleteByConversationOrThread(conversationID string, threadID string) int { + if s == nil { + return 0 + } + conversationID = strings.TrimSpace(conversationID) + threadID = strings.TrimSpace(threadID) + if conversationID == "" && threadID == "" { + return 0 + } + s.ensureInitialized() + removed := 0 + for responseID, record := range s.items { + if (conversationID != "" && strings.TrimSpace(record.ConversationID) == conversationID) || + (threadID != "" && strings.TrimSpace(record.ThreadID) == threadID) { + delete(s.items, responseID) + removed++ + } + } + return removed +} diff --git a/internal/app/session_refresh.go b/internal/app/session_refresh.go index 85c0b2c..013b981 100644 --- a/internal/app/session_refresh.go +++ b/internal/app/session_refresh.go @@ -9,6 +9,11 @@ import ( "time" ) +var ( + testHookTryRefreshAccount func(context.Context, AppConfig, NotionAccount) (AppConfig, error) + testHookSaveAndApply func(*ServerState, AppConfig) error +) + func 
sessionRefreshNowISO() string { return time.Now().Format(time.RFC3339) } @@ -81,7 +86,7 @@ func loadSessionInfoForAccountRefresh(cfg AppConfig, account NotionAccount) (Ses func buildRefreshedSession(ctx context.Context, cfg AppConfig, account NotionAccount, prior SessionInfo) (SessionInfo, error) { upstream := cfg.NotionUpstream() resolver := NewProxyResolver(cfg) - session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, account.Email) + session, err := newNotionLoginSession(helperTimeout(cfg), upstream, resolver, account.Email, cfg) if err != nil { return SessionInfo{}, err } @@ -278,45 +283,62 @@ func (s *ServerState) RefreshSession(ctx context.Context, reason string) error { return fmt.Errorf("no active account configured for session refresh") } - updatedCfg, err := s.tryRefreshAccount(ctx, cfg, account) + tryRefresh := s.tryRefreshAccount + if testHookTryRefreshAccount != nil { + tryRefresh = testHookTryRefreshAccount + } + saveAndApply := s.SaveAndApply + if testHookSaveAndApply != nil { + saveAndApply = func(cfg AppConfig) error { + return testHookSaveAndApply(s, cfg) + } + } + + updatedCfg, err := tryRefresh(ctx, cfg, account) if err == nil { - if saveErr := s.SaveAndApply(updatedCfg); saveErr != nil { + if saveErr := saveAndApply(updatedCfg); saveErr != nil { s.setSessionRefreshRuntime(saveErr) return saveErr } + if s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } s.setSessionRefreshRuntime(nil) return nil } if !refreshCfg.AutoSwitch { s.setSessionRefreshRuntime(err) - _ = s.SaveAndApply(updatedCfg) + _ = saveAndApply(updatedCfg) return fmt.Errorf("refresh active account %s failed (%s): %w", account.Email, reason, err) } lastErr := err for _, candidate := range cfg.Accounts { - if canonicalEmailKey(candidate.Email) == canonicalEmailKey(account.Email) { + if getAccountEmailKey(candidate) == getAccountEmailKey(account) { continue } if !fileExists(ensureAccountPaths(cfg, candidate).ProbeJSON) { continue } - 
nextCfg, nextErr := s.tryRefreshAccount(ctx, updatedCfg, candidate) + nextCfg, nextErr := tryRefresh(ctx, updatedCfg, candidate) if nextErr != nil { lastErr = nextErr updatedCfg = nextCfg continue } - if saveErr := s.SaveAndApply(nextCfg); saveErr != nil { + if saveErr := saveAndApply(nextCfg); saveErr != nil { s.setSessionRefreshRuntime(saveErr) return saveErr } + if s.DispatchProbeCache != nil { + s.DispatchProbeCache.invalidateAll() + } s.setSessionRefreshRuntime(nil) return nil } - _ = s.SaveAndApply(updatedCfg) + _ = saveAndApply(updatedCfg) s.setSessionRefreshRuntime(lastErr) return fmt.Errorf("session refresh failed after trying active account and fallbacks (%s): %w", reason, lastErr) } diff --git a/internal/app/sqlite_store.go b/internal/app/sqlite_store.go index de3f91d..6dfb596 100644 --- a/internal/app/sqlite_store.go +++ b/internal/app/sqlite_store.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strings" "time" @@ -14,9 +15,18 @@ import ( type SQLiteStore struct { db *sql.DB + roDB *sql.DB path string } +func observeSQLiteDuration(op string, startedAt time.Time) { + if startedAt.IsZero() { + return + } + elapsed := time.Since(startedAt) + observeSQLiteOpDuration(op, elapsed) +} + func openSQLiteStore(cfg AppConfig) (*SQLiteStore, error) { path := strings.TrimSpace(cfg.ResolveSQLitePath()) if path == "" { @@ -33,11 +43,21 @@ func openSQLiteStore(cfg AppConfig) (*SQLiteStore, error) { return nil, fmt.Errorf("open sqlite: %w", err) } db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) store := &SQLiteStore{db: db, path: path} if err := store.init(); err != nil { _ = db.Close() return nil, err } + roDB, err := sql.Open("sqlite", fmt.Sprintf("file:%s?mode=ro&_journal=WAL", path)) + if err != nil { + _ = db.Close() + return nil, fmt.Errorf("open sqlite read-only: %w", err) + } + readers := maxInt(2, runtime.NumCPU()) + roDB.SetMaxOpenConns(readers) + roDB.SetMaxIdleConns(readers) + store.roDB = roDB return store, nil } @@ -49,10 +69,21 @@ func 
(s *SQLiteStore) Path() string { } func (s *SQLiteStore) Close() error { - if s == nil || s.db == nil { + if s == nil { return nil } - return s.db.Close() + var closeErr error + if s.roDB != nil { + if err := s.roDB.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + if s.db != nil { + if err := s.db.Close(); err != nil && closeErr == nil { + closeErr = err + } + } + return closeErr } func (s *SQLiteStore) init() error { @@ -64,6 +95,10 @@ func (s *SQLiteStore) init() error { `PRAGMA busy_timeout=5000;`, `PRAGMA synchronous=NORMAL;`, `PRAGMA foreign_keys=ON;`, + `PRAGMA mmap_size=268435456;`, + `PRAGMA cache_size=-65536;`, + `PRAGMA temp_store=MEMORY;`, + `PRAGMA wal_autocheckpoint=1000;`, } for _, stmt := range pragmas { if _, err := s.db.Exec(stmt); err != nil { @@ -164,7 +199,19 @@ func (s *SQLiteStore) init() error { return nil } +func (s *SQLiteStore) readDB() *sql.DB { + if s == nil { + return nil + } + if s.roDB != nil { + return s.roDB + } + return s.db +} + func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_accounts", startedAt) if s == nil || s.db == nil { return nil } @@ -190,7 +237,7 @@ func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { return err } active := 0 - if canonicalEmailKey(account.Email) == activeKey { + if getAccountEmailKey(account) == activeKey { active = 1 } if _, err = tx.Exec( @@ -211,10 +258,13 @@ func (s *SQLiteStore) SaveAccounts(cfg AppConfig) error { } func (s *SQLiteStore) LoadAccounts() ([]NotionAccount, string, bool, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_accounts", startedAt) + db := s.readDB() + if db == nil { return nil, "", false, nil } - rows, err := s.db.Query(`SELECT data_json, active FROM accounts ORDER BY position ASC, email ASC`) + rows, err := db.Query(`SELECT data_json, active FROM accounts ORDER BY position ASC, email ASC`) if err != nil { return nil, "", false, 
err } @@ -243,6 +293,8 @@ func (s *SQLiteStore) LoadAccounts() ([]NotionAccount, string, bool, error) { } func (s *SQLiteStore) SaveConversation(entry ConversationEntry) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation", startedAt) if s == nil || s.db == nil { return nil } @@ -268,6 +320,8 @@ func (s *SQLiteStore) SaveConversation(entry ConversationEntry) error { } func (s *SQLiteStore) DeleteConversation(id string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_conversation", startedAt) if s == nil || s.db == nil || strings.TrimSpace(id) == "" { return nil } @@ -276,6 +330,8 @@ func (s *SQLiteStore) DeleteConversation(id string) error { } func (s *SQLiteStore) DeleteResponsesByConversationOrThread(conversationID string, threadID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_responses_by_conversation_or_thread", startedAt) if s == nil || s.db == nil { return nil } @@ -297,10 +353,13 @@ func (s *SQLiteStore) DeleteResponsesByConversationOrThread(conversationID strin } func (s *SQLiteStore) LoadConversations() ([]ConversationEntry, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversations", startedAt) + db := s.readDB() + if db == nil { return nil, nil } - rows, err := s.db.Query(`SELECT data_json FROM conversations ORDER BY updated_at DESC, created_at DESC LIMIT ?`, maxConversationEntries) + rows, err := db.Query(`SELECT data_json FROM conversations ORDER BY updated_at DESC, created_at DESC LIMIT ?`, maxConversationEntries) if err != nil { return nil, err } @@ -324,6 +383,8 @@ func (s *SQLiteStore) LoadConversations() ([]ConversationEntry, error) { } func (s *SQLiteStore) SaveResponse(responseID string, payload map[string]any, createdAt time.Time, conversationID string, threadID string, accountEmail string) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_response", startedAt) if s == nil || s.db 
== nil || strings.TrimSpace(responseID) == "" { return nil } @@ -351,6 +412,8 @@ func (s *SQLiteStore) SaveResponse(responseID string, payload map[string]any, cr } func (s *SQLiteStore) DeleteExpiredResponses(ttl time.Duration) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_expired_responses", startedAt) if s == nil || s.db == nil || ttl <= 0 { return nil } @@ -360,13 +423,16 @@ func (s *SQLiteStore) DeleteExpiredResponses(ttl time.Duration) error { } func (s *SQLiteStore) LoadResponses(ttl time.Duration) (map[string]StoredResponse, error) { - if s == nil || s.db == nil { + startedAt := time.Now() + defer observeSQLiteDuration("load_responses", startedAt) + db := s.readDB() + if db == nil { return map[string]StoredResponse{}, nil } if err := s.DeleteExpiredResponses(ttl); err != nil { return nil, err } - rows, err := s.db.Query(`SELECT response_id, created_at, payload_json, conversation_id, thread_id, account_email FROM responses ORDER BY created_at DESC`) + rows, err := db.Query(`SELECT response_id, created_at, payload_json, conversation_id, thread_id, account_email FROM responses ORDER BY created_at DESC`) if err != nil { return nil, err } @@ -405,6 +471,8 @@ func (s *SQLiteStore) LoadResponses(ttl time.Duration) (map[string]StoredRespons } func (s *SQLiteStore) SaveConversationSession(session ConversationSession) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation_session", startedAt) if s == nil || s.db == nil || strings.TrimSpace(session.ID) == "" { return nil } @@ -451,6 +519,8 @@ func (s *SQLiteStore) SaveConversationSession(session ConversationSession) error } func (s *SQLiteStore) SaveConversationSessionStep(step ConversationSessionStep) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_conversation_session_step", startedAt) if s == nil || s.db == nil || strings.TrimSpace(step.SessionID) == "" || strings.TrimSpace(step.UpdatedConfigID) == "" { return nil } @@ -489,10 +559,13 @@ 
func (s *SQLiteStore) LoadConversationSessionBySessionID(sessionID string) (Conv } func (s *SQLiteStore) loadConversationSession(query string, arg string) (ConversationSession, bool, error) { - if s == nil || s.db == nil || strings.TrimSpace(arg) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversation_session", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(arg) == "" { return ConversationSession{}, false, nil } - row := s.db.QueryRow(query, arg) + row := db.QueryRow(query, arg) var ( session ConversationSession createdAtText, updatedAtText, lastUsedAtText, deletedAtText string @@ -543,10 +616,13 @@ func (s *SQLiteStore) loadConversationSession(query string, arg string) (Convers } func (s *SQLiteStore) LoadConversationSessionStepIDs(sessionID string) ([]string, error) { - if s == nil || s.db == nil || strings.TrimSpace(sessionID) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_conversation_session_step_ids", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(sessionID) == "" { return nil, nil } - rows, err := s.db.Query(`SELECT updated_config_id FROM conversation_session_steps WHERE session_id = ? ORDER BY step_index ASC`, strings.TrimSpace(sessionID)) + rows, err := db.Query(`SELECT updated_config_id FROM conversation_session_steps WHERE session_id = ? 
ORDER BY step_index ASC`, strings.TrimSpace(sessionID)) if err != nil { return nil, err } @@ -563,6 +639,8 @@ func (s *SQLiteStore) LoadConversationSessionStepIDs(sessionID string) ([]string } func (s *SQLiteStore) MarkConversationSessionStatus(sessionID string, status string) error { + startedAt := time.Now() + defer observeSQLiteDuration("mark_conversation_session_status", startedAt) if s == nil || s.db == nil || strings.TrimSpace(sessionID) == "" { return nil } @@ -581,6 +659,8 @@ func (s *SQLiteStore) MarkConversationSessionStatus(sessionID string, status str } func (s *SQLiteStore) DeleteConversationSessionByConversationOrThread(conversationID string, threadID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_conversation_session_by_conversation_or_thread", startedAt) if s == nil || s.db == nil { return nil } @@ -644,6 +724,8 @@ func (s *SQLiteStore) DeleteConversationSessionByConversationOrThread(conversati } func (s *SQLiteStore) SaveSillyTavernBinding(binding SillyTavernBinding) error { + startedAt := time.Now() + defer observeSQLiteDuration("save_sillytavern_binding", startedAt) if s == nil || s.db == nil || strings.TrimSpace(binding.ConversationID) == "" { return nil } @@ -676,13 +758,16 @@ func (s *SQLiteStore) SaveSillyTavernBinding(binding SillyTavernBinding) error { } func (s *SQLiteStore) LoadRecentSillyTavernBindings(profileKey string, limit int) ([]SillyTavernBinding, error) { - if s == nil || s.db == nil || strings.TrimSpace(profileKey) == "" { + startedAt := time.Now() + defer observeSQLiteDuration("load_recent_sillytavern_bindings", startedAt) + db := s.readDB() + if db == nil || strings.TrimSpace(profileKey) == "" { return nil, nil } if limit <= 0 { limit = 12 } - rows, err := s.db.Query( + rows, err := db.Query( `SELECT conversation_id, profile_key, thread_id, account_email, mode, transcript_json, raw_message_count, updated_at FROM sillytavern_bindings WHERE profile_key = ? 
@@ -729,6 +814,8 @@ func (s *SQLiteStore) LoadRecentSillyTavernBindings(profileKey string, limit int } func (s *SQLiteStore) DeleteSillyTavernBinding(conversationID string) error { + startedAt := time.Now() + defer observeSQLiteDuration("delete_sillytavern_binding", startedAt) if s == nil || s.db == nil || strings.TrimSpace(conversationID) == "" { return nil } diff --git a/internal/app/sqlite_writer.go b/internal/app/sqlite_writer.go new file mode 100644 index 0000000..c6d28ef --- /dev/null +++ b/internal/app/sqlite_writer.go @@ -0,0 +1,203 @@ +package app + +import ( + "expvar" + "log" + "strings" + "sync" + "sync/atomic" + "time" +) + +const defaultSQLiteWriterQueueSize = 1024 + +var sqliteWriterFallbackTotalMetric = expvar.NewMap("notion2api_sqlite_writer_fallback_total") + +type sqlitePersistOpKind uint8 + +const ( + sqlitePersistOpSaveResponse sqlitePersistOpKind = iota + 1 + sqlitePersistOpDeleteResponsesByConversationOrThread +) + +type sqlitePersistOp struct { + kind sqlitePersistOpKind + responseID string + payload map[string]any + createdAt time.Time + conversationID string + threadID string + accountEmail string +} + +type SQLiteWriter struct { + store *SQLiteStore + queue chan sqlitePersistOp + done chan struct{} + ttlNanos atomic.Int64 + + mu sync.RWMutex + closed bool +} + +func newSQLiteWriter(store *SQLiteStore, ttl time.Duration) *SQLiteWriter { + if store == nil { + return nil + } + writer := &SQLiteWriter{ + store: store, + queue: make(chan sqlitePersistOp, defaultSQLiteWriterQueueSize), + done: make(chan struct{}), + } + writer.SetTTL(ttl) + go writer.run() + return writer +} + +func (w *SQLiteWriter) SetTTL(ttl time.Duration) { + if w == nil { + return + } + if ttl <= 0 { + ttl = time.Second + } + w.ttlNanos.Store(int64(ttl)) +} + +func (w *SQLiteWriter) EnqueueSaveResponse(responseID string, payload map[string]any, createdAt time.Time, conversationID string, threadID string, accountEmail string) { + if w == nil || w.store == nil { + return + } 
+ responseID = strings.TrimSpace(responseID) + if responseID == "" { + return + } + op := sqlitePersistOp{ + kind: sqlitePersistOpSaveResponse, + responseID: responseID, + payload: clonePersistPayload(payload), + createdAt: createdAt, + conversationID: strings.TrimSpace(conversationID), + threadID: strings.TrimSpace(threadID), + accountEmail: strings.TrimSpace(accountEmail), + } + if w.tryEnqueue(op) { + return + } + sqliteWriterFallbackTotalMetric.Add("channel_full", 1) + w.apply(op) +} + +func (w *SQLiteWriter) EnqueueDeleteResponsesByConversationOrThread(conversationID string, threadID string) { + if w == nil || w.store == nil { + return + } + conversationID = strings.TrimSpace(conversationID) + threadID = strings.TrimSpace(threadID) + if conversationID == "" && threadID == "" { + return + } + op := sqlitePersistOp{ + kind: sqlitePersistOpDeleteResponsesByConversationOrThread, + conversationID: conversationID, + threadID: threadID, + } + if w.enqueueBlocking(op) { + return + } + sqliteWriterFallbackTotalMetric.Add("writer_unavailable", 1) + w.apply(op) +} + +func (w *SQLiteWriter) Close() { + if w == nil { + return + } + w.mu.Lock() + if w.closed { + w.mu.Unlock() + return + } + w.closed = true + close(w.queue) + w.mu.Unlock() + <-w.done +} + +func (w *SQLiteWriter) tryEnqueue(op sqlitePersistOp) bool { + w.mu.RLock() + defer w.mu.RUnlock() + if w.closed { + return false + } + select { + case w.queue <- op: + return true + default: + return false + } +} + +func (w *SQLiteWriter) enqueueBlocking(op sqlitePersistOp) bool { + w.mu.RLock() + defer w.mu.RUnlock() + if w.closed { + return false + } + select { + case w.queue <- op: + return true + default: + w.queue <- op + return true + } +} + +func (w *SQLiteWriter) run() { + defer close(w.done) + for op := range w.queue { + w.apply(op) + } +} + +func (w *SQLiteWriter) apply(op sqlitePersistOp) { + if w == nil || w.store == nil { + return + } + switch op.kind { + case sqlitePersistOpSaveResponse: + if err := 
w.store.SaveResponse(op.responseID, op.payload, op.createdAt, op.conversationID, op.threadID, op.accountEmail); err != nil { + log.Printf("[sqlite-writer] save response %s failed: %v", op.responseID, err) + return + } + if err := w.store.DeleteExpiredResponses(w.ttl()); err != nil { + log.Printf("[sqlite-writer] cleanup responses failed: %v", err) + } + case sqlitePersistOpDeleteResponsesByConversationOrThread: + if err := w.store.DeleteResponsesByConversationOrThread(op.conversationID, op.threadID); err != nil { + log.Printf("[sqlite-writer] delete responses conversation=%s thread=%s failed: %v", op.conversationID, op.threadID, err) + } + } +} + +func (w *SQLiteWriter) ttl() time.Duration { + if w == nil { + return time.Second + } + ttlNanos := w.ttlNanos.Load() + if ttlNanos <= 0 { + return time.Second + } + return time.Duration(ttlNanos) +} + +func clonePersistPayload(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + dst := make(map[string]any, len(src)) + for key, value := range src { + dst[key] = value + } + return dst +} diff --git a/internal/wreq/doc.go b/internal/wreq/doc.go deleted file mode 100644 index 614a60b..0000000 --- a/internal/wreq/doc.go +++ /dev/null @@ -1 +0,0 @@ -package wreq diff --git a/internal/wreq/wreq_cgo.go b/internal/wreq/wreq_cgo.go deleted file mode 100644 index e18f1e6..0000000 --- a/internal/wreq/wreq_cgo.go +++ /dev/null @@ -1,124 +0,0 @@ -//go:build wreq_ffi - -package wreq - -/* -#cgo CFLAGS: -I${SRCDIR}/../../wreq-ffi/include -#cgo LDFLAGS: ${SRCDIR}/../../wreq-ffi/target/release/libwreq_ffi.a -ldl -lm -lpthread - -#include -#include "wreq_ffi.h" -*/ -import "C" - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "runtime" - "sync/atomic" - "unsafe" -) - -type ClientConfig struct { - Emulation string `json:"emulation,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` - CookieStore *bool `json:"cookie_store,omitempty"` - ProxyURL string 
`json:"proxy_url,omitempty"` - AcceptInvalidCerts *bool `json:"accept_invalid_certs,omitempty"` -} - -type RequestSpec struct { - Method string `json:"method"` - URL string `json:"url"` - Headers [][]string `json:"headers,omitempty"` - BodyB64 string `json:"body_b64,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` -} - -type Response struct { - OK bool `json:"ok"` - Status int `json:"status"` - Headers [][]string `json:"headers"` - BodyB64 string `json:"body_b64"` - FinalURL string `json:"final_url"` - Error string `json:"error,omitempty"` -} - -func (r *Response) Body() ([]byte, error) { - if r.BodyB64 == "" { - return nil, nil - } - return base64.StdEncoding.DecodeString(r.BodyB64) -} - -type Client struct { - handle *C.struct_WreqClient - closed atomic.Bool -} - -func New(cfg ClientConfig) (*Client, error) { - profile, err := json.Marshal(cfg) - if err != nil { - return nil, fmt.Errorf("wreq: marshal config: %w", err) - } - cProfile := C.CString(string(profile)) - defer C.free(unsafe.Pointer(cProfile)) - - handle := C.wreq_client_new(cProfile) - if handle == nil { - return nil, errors.New("wreq: wreq_client_new returned NULL (bad config?)") - } - c := &Client{handle: handle} - runtime.SetFinalizer(c, func(c *Client) { _ = c.Close() }) - return c, nil -} - -func (c *Client) Close() error { - if c == nil || !c.closed.CompareAndSwap(false, true) { - return nil - } - C.wreq_client_free(c.handle) - c.handle = nil - runtime.SetFinalizer(c, nil) - return nil -} - -func (c *Client) Do(ctx context.Context, spec RequestSpec) (*Response, error) { - if c == nil || c.handle == nil || c.closed.Load() { - return nil, errors.New("wreq: client closed") - } - if err := ctx.Err(); err != nil { - return nil, err - } - - reqJSON, err := json.Marshal(spec) - if err != nil { - return nil, fmt.Errorf("wreq: marshal request: %w", err) - } - - cReq := C.CString(string(reqJSON)) - defer C.free(unsafe.Pointer(cReq)) - - cResp := C.wreq_request(c.handle, cReq) - if cResp == 
nil { - return nil, errors.New("wreq: wreq_request returned NULL") - } - defer C.wreq_string_free(cResp) - - goResp := C.GoString(cResp) - var resp Response - if err := json.Unmarshal([]byte(goResp), &resp); err != nil { - return nil, fmt.Errorf("wreq: unmarshal response: %w", err) - } - if !resp.OK { - return &resp, fmt.Errorf("wreq: %s", resp.Error) - } - return &resp, nil -} - -func Version() string { - return C.GoString(C.wreq_ffi_version()) -} diff --git a/internal/wreq/wreq_stub.go b/internal/wreq/wreq_stub.go deleted file mode 100644 index ae6e862..0000000 --- a/internal/wreq/wreq_stub.go +++ /dev/null @@ -1,49 +0,0 @@ -//go:build !wreq_ffi - -package wreq - -import ( - "context" - "errors" -) - -var ErrNotLinked = errors.New("wreq: built without wreq_ffi tag; use node-wreq fallback") - -type ClientConfig struct { - Emulation string `json:"emulation,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` - CookieStore *bool `json:"cookie_store,omitempty"` - ProxyURL string `json:"proxy_url,omitempty"` - AcceptInvalidCerts *bool `json:"accept_invalid_certs,omitempty"` -} - -type RequestSpec struct { - Method string `json:"method"` - URL string `json:"url"` - Headers [][]string `json:"headers,omitempty"` - BodyB64 string `json:"body_b64,omitempty"` - TimeoutSecs uint64 `json:"timeout_secs,omitempty"` -} - -type Response struct { - OK bool `json:"ok"` - Status int `json:"status"` - Headers [][]string `json:"headers"` - BodyB64 string `json:"body_b64"` - FinalURL string `json:"final_url"` - Error string `json:"error,omitempty"` -} - -func (r *Response) Body() ([]byte, error) { return nil, ErrNotLinked } - -type Client struct{} - -func New(_ ClientConfig) (*Client, error) { return nil, ErrNotLinked } - -func (c *Client) Close() error { return nil } - -func (c *Client) Do(_ context.Context, _ RequestSpec) (*Response, error) { - return nil, ErrNotLinked -} - -func Version() string { return "unlinked" } diff --git a/scripts/perf/baseline.sh 
b/scripts/perf/baseline.sh new file mode 100644 index 0000000..d1175de --- /dev/null +++ b/scripts/perf/baseline.sh @@ -0,0 +1,299 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +PAYLOAD_TEMPLATE="${ROOT_DIR}/scripts/perf/payload-chat.json" +OUT_ROOT="${ROOT_DIR}/docs/perf" + +BASE_URL="${N2A_BASE_URL:-http://127.0.0.1:8787}" +PPROF_BASE="${N2A_PPROF_BASE:-http://127.0.0.1:6060}" +API_KEY="${N2A_API_KEY:-change-me-openai-key}" +CONCURRENCY="${N2A_PERF_CONCURRENCY:-50}" +DURATION="${N2A_PERF_DURATION:-60s}" + +if [[ ! -f "${PAYLOAD_TEMPLATE}" ]]; then + echo "payload template not found: ${PAYLOAD_TEMPLATE}" >&2 + exit 1 +fi + +if ! command -v curl >/dev/null 2>&1; then + echo "curl is required" >&2 + exit 1 +fi + +if ! command -v python >/dev/null 2>&1; then + echo "python is required" >&2 + exit 1 +fi + +LOAD_TOOL="" +if command -v hey >/dev/null 2>&1; then + LOAD_TOOL="hey" +elif command -v vegeta >/dev/null 2>&1; then + LOAD_TOOL="vegeta" +else + echo "either hey or vegeta must be installed" >&2 + exit 1 +fi + +if ! curl -fsS "${BASE_URL}/healthz" >/dev/null; then + echo "service is not reachable: ${BASE_URL}/healthz" >&2 + exit 1 +fi + +if ! curl -fsS "${PPROF_BASE}/debug/pprof/" >/dev/null; then + echo "pprof endpoint is not reachable: ${PPROF_BASE}/debug/pprof/" >&2 + echo "enable config.debug.pprof_enabled=true and keep pprof_addr local-only." 
>&2 + exit 1 +fi + +stamp="$(date -u +%Y%m%d-%H%M%S)" +git_sha="$(git -C "${ROOT_DIR}" rev-parse --short HEAD 2>/dev/null || echo "nogit")" +OUT_DIR="${OUT_ROOT}/${stamp}-${git_sha}" +mkdir -p "${OUT_DIR}" + +tmp_dir="$(mktemp -d "${OUT_DIR}/tmp.XXXXXX")" +cleanup() { + rm -rf "${tmp_dir}" +} +trap cleanup EXIT + +non_stream_payload="${tmp_dir}/chat.nonstream.json" +stream_payload="${tmp_dir}/chat.stream.json" +python - "${PAYLOAD_TEMPLATE}" "${non_stream_payload}" "${stream_payload}" <<'PY' +import json +import sys +from pathlib import Path + +template = Path(sys.argv[1]) +non_stream = Path(sys.argv[2]) +stream = Path(sys.argv[3]) +payload = json.loads(template.read_text(encoding="utf-8")) + +payload_non_stream = dict(payload) +payload_non_stream["stream"] = False +non_stream.write_text(json.dumps(payload_non_stream, ensure_ascii=False), encoding="utf-8") + +payload_stream = dict(payload) +payload_stream["stream"] = True +stream.write_text(json.dumps(payload_stream, ensure_ascii=False), encoding="utf-8") +PY + +request_url="${BASE_URL}/v1/chat/completions" +auth_header="Authorization: Bearer ${API_KEY}" +content_header="Content-Type: application/json" + +run_hey() { + local payload_file="$1" + local output_file="$2" + hey -z "${DURATION}" -c "${CONCURRENCY}" -m POST \ + -H "${auth_header}" \ + -H "${content_header}" \ + -D "${payload_file}" \ + "${request_url}" >"${output_file}" +} + +run_vegeta() { + local payload_file="$1" + local output_txt="$2" + local output_bin="$3" + local output_json="$4" + printf "POST %s\n" "${request_url}" | vegeta attack \ + -duration="${DURATION}" \ + -workers="${CONCURRENCY}" \ + -max-workers="${CONCURRENCY}" \ + -body="${payload_file}" \ + -header="${auth_header}" \ + -header="${content_header}" >"${output_bin}" + vegeta report "${output_bin}" >"${output_txt}" + vegeta report -type=json "${output_bin}" >"${output_json}" +} + +extract_percentile() { + local report_file="$1" + local percentile="$2" + awk -v p="${percentile}" '$1==p 
{print $3 " " $4}' "${report_file}" | head -n1 +} + +extract_requests_per_sec() { + local report_file="$1" + awk '$1=="Requests/sec:" {print $2}' "${report_file}" | head -n1 +} + +rss_bytes() { + local pid="$1" + if [[ -f "/proc/${pid}/status" ]]; then + awk '/VmRSS:/ {print $2*1024; exit}' "/proc/${pid}/status" + return + fi + if command -v ps >/dev/null 2>&1; then + local rss_kb + rss_kb="$(ps -o rss= -p "${pid}" | awk '{print $1}' | head -n1 || true)" + if [[ -n "${rss_kb}" ]]; then + echo $((rss_kb * 1024)) + return + fi + fi + echo "" +} + +peak_rss_bytes=0 +stream_pid="" + +if [[ "${LOAD_TOOL}" == "hey" ]]; then + run_hey "${non_stream_payload}" "${OUT_DIR}/nonstream-hey.txt" + + run_hey "${stream_payload}" "${OUT_DIR}/stream-hey.txt" & + stream_pid=$! + + for _ in $(seq 1 10); do + if ! kill -0 "${stream_pid}" >/dev/null 2>&1; then + break + fi + current_rss="$(rss_bytes "${stream_pid}" || true)" + if [[ -n "${current_rss}" ]] && (( current_rss > peak_rss_bytes )); then + peak_rss_bytes="${current_rss}" + fi + sleep 1 + done + + curl -fsS "${PPROF_BASE}/debug/pprof/profile?seconds=30" -o "${OUT_DIR}/cpu.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/heap" -o "${OUT_DIR}/heap.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/goroutine?debug=0" -o "${OUT_DIR}/goroutine.pb.gz" + wait "${stream_pid}" +else + run_vegeta "${non_stream_payload}" "${OUT_DIR}/nonstream-vegeta.txt" "${OUT_DIR}/nonstream-vegeta.bin" "${OUT_DIR}/nonstream-vegeta.json" + run_vegeta "${stream_payload}" "${OUT_DIR}/stream-vegeta.txt" "${OUT_DIR}/stream-vegeta.bin" "${OUT_DIR}/stream-vegeta.json" & + stream_pid=$! + + for _ in $(seq 1 10); do + if ! 
kill -0 "${stream_pid}" >/dev/null 2>&1; then + break + fi + current_rss="$(rss_bytes "${stream_pid}" || true)" + if [[ -n "${current_rss}" ]] && (( current_rss > peak_rss_bytes )); then + peak_rss_bytes="${current_rss}" + fi + sleep 1 + done + + curl -fsS "${PPROF_BASE}/debug/pprof/profile?seconds=30" -o "${OUT_DIR}/cpu.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/heap" -o "${OUT_DIR}/heap.pb.gz" + curl -fsS "${PPROF_BASE}/debug/pprof/goroutine?debug=0" -o "${OUT_DIR}/goroutine.pb.gz" + wait "${stream_pid}" +fi + +if [[ "${LOAD_TOOL}" == "hey" ]]; then + nonstream_report="${OUT_DIR}/nonstream-hey.txt" + stream_report="${OUT_DIR}/stream-hey.txt" + nonstream_p50="$(extract_percentile "${nonstream_report}" "50%" || true)" + nonstream_p95="$(extract_percentile "${nonstream_report}" "95%" || true)" + nonstream_p99="$(extract_percentile "${nonstream_report}" "99%" || true)" + stream_p50="$(extract_percentile "${stream_report}" "50%" || true)" + stream_p95="$(extract_percentile "${stream_report}" "95%" || true)" + stream_p99="$(extract_percentile "${stream_report}" "99%" || true)" + nonstream_rps="$(extract_requests_per_sec "${nonstream_report}" || true)" + stream_rps="$(extract_requests_per_sec "${stream_report}" || true)" +else + nonstream_report="${OUT_DIR}/nonstream-vegeta.txt" + stream_report="${OUT_DIR}/stream-vegeta.txt" + eval "$( + python - "${OUT_DIR}/nonstream-vegeta.json" "${OUT_DIR}/stream-vegeta.json" <<'PY' +import json +import sys + +def fmt_ns(value): + if value is None: + return "" + ns = float(value) + if ns >= 1_000_000_000: + return f"{ns/1_000_000_000:.4f} s" + if ns >= 1_000_000: + return f"{ns/1_000_000:.4f} ms" + if ns >= 1_000: + return f"{ns/1_000:.4f} us" + return f"{ns:.0f} ns" + +def load(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + +def pick_latency(report, key): + lat = report.get("latencies", {}) + if key in lat: + return lat[key] + fallback = { + "50th": "50", + "95th": "95", + "99th": "99", + }.get(key) 
+ return lat.get(fallback) + +non = load(sys.argv[1]) +st = load(sys.argv[2]) + +print(f"nonstream_p50={fmt_ns(pick_latency(non, '50th'))!r}") +print(f"nonstream_p95={fmt_ns(pick_latency(non, '95th'))!r}") +print(f"nonstream_p99={fmt_ns(pick_latency(non, '99th'))!r}") +print(f"stream_p50={fmt_ns(pick_latency(st, '50th'))!r}") +print(f"stream_p95={fmt_ns(pick_latency(st, '95th'))!r}") +print(f"stream_p99={fmt_ns(pick_latency(st, '99th'))!r}") +print(f"nonstream_rps={str(non.get('throughput', ''))!r}") +print(f"stream_rps={str(st.get('throughput', ''))!r}") +PY +)" +fi + +rss_mib="n/a" +if (( peak_rss_bytes > 0 )); then + rss_mib="$(python - <"${OUT_DIR}/summary.md" < = OnceCell::new(); - -fn runtime() -> &'static Runtime { - RUNTIME.get_or_init(|| { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .thread_name("wreq-ffi") - .build() - .expect("wreq-ffi: failed to create tokio runtime") - }) -} - - -pub struct WreqClient { - inner: wreq::Client, -} - - -#[derive(Default, Deserialize)] -struct ClientConfig { - #[serde(default)] - emulation: Option, - #[serde(default)] - timeout_secs: Option, - #[serde(default)] - proxy_url: Option, -} - -#[derive(Deserialize)] -struct RequestSpec { - method: String, - url: String, - #[serde(default)] - headers: Vec<(String, String)>, - #[serde(default)] - body_b64: Option, - #[serde(default)] - timeout_secs: Option, -} - -#[derive(Serialize)] -struct ResponseEnvelope { - ok: bool, - status: u16, - headers: Vec<(String, String)>, - body_b64: String, - final_url: String, -} - -#[derive(Serialize)] -struct ErrorEnvelope<'a> { - ok: bool, - error: &'a str, -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_client_new(profile_json: *const c_char) -> *mut WreqClient { - catch_unwind(AssertUnwindSafe(|| { - let cfg: ClientConfig = if profile_json.is_null() { - ClientConfig::default() - } else { - let s = match CStr::from_ptr(profile_json).to_str() { - Ok(s) => s, - Err(_) => return ptr::null_mut::(), - }; - match 
serde_json::from_str(s) { - Ok(c) => c, - Err(_) => return ptr::null_mut(), - } - }; - - let mut builder = wreq::Client::builder(); - if let Some(secs) = cfg.timeout_secs { - builder = builder.timeout(Duration::from_secs(secs)); - } - if let Some(p) = cfg.proxy_url.as_deref() { - if let Ok(proxy) = wreq::Proxy::all(p) { - builder = builder.proxy(proxy); - } - } - let _ = cfg.emulation; - - match builder.build() { - Ok(client) => Box::into_raw(Box::new(WreqClient { inner: client })), - Err(_) => ptr::null_mut(), - } - })) - .unwrap_or(ptr::null_mut()) -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_client_free(client: *mut WreqClient) { - if !client.is_null() { - drop(Box::from_raw(client)); - } -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_request( - client: *mut WreqClient, - request_json: *const c_char, -) -> *mut c_char { - catch_unwind(AssertUnwindSafe(|| { - if client.is_null() || request_json.is_null() { - return error_response("nil client or request"); - } - let client = &*client; - let raw = match CStr::from_ptr(request_json).to_str() { - Ok(s) => s, - Err(_) => return error_response("request_json: invalid utf-8"), - }; - let spec: RequestSpec = match serde_json::from_str(raw) { - Ok(s) => s, - Err(e) => return error_response(&format!("request_json: {e}")), - }; - runtime().block_on(do_request(&client.inner, spec)) - })) - .unwrap_or_else(|_| error_response("rust panic in wreq_request")) -} - -#[no_mangle] -pub unsafe extern "C" fn wreq_string_free(ptr: *mut c_char) { - if !ptr.is_null() { - drop(CString::from_raw(ptr)); - } -} - -#[no_mangle] -pub extern "C" fn wreq_ffi_version() -> *const c_char { - static VERSION: &[u8] = concat!(env!("CARGO_PKG_VERSION"), "\0").as_bytes(); - VERSION.as_ptr() as *const c_char -} - - -async fn do_request(client: &wreq::Client, spec: RequestSpec) -> *mut c_char { - let method = match spec.method.parse::() { - Ok(m) => m, - Err(e) => return error_response(&format!("bad method: {e}")), - }; - let mut req = 
client.request(method, &spec.url); - for (k, v) in &spec.headers { - req = req.header(k.as_str(), v.as_str()); - } - if let Some(secs) = spec.timeout_secs { - req = req.timeout(Duration::from_secs(secs)); - } - if let Some(b64) = spec.body_b64.as_deref() { - if !b64.is_empty() { - match base64_decode(b64) { - Ok(bytes) => req = req.body(bytes), - Err(e) => return error_response(&format!("body_b64: {e}")), - } - } - } - - let resp = match req.send().await { - Ok(r) => r, - Err(e) => return error_response(&format!("send: {e}")), - }; - let status = resp.status().as_u16(); - let final_url = resp.url().to_string(); - let mut headers: Vec<(String, String)> = Vec::with_capacity(resp.headers().len()); - for (k, v) in resp.headers().iter() { - headers.push(( - k.as_str().to_string(), - v.to_str().unwrap_or("").to_string(), - )); - } - let bytes = match resp.bytes().await { - Ok(b) => b, - Err(e) => return error_response(&format!("read body: {e}")), - }; - let env = ResponseEnvelope { - ok: true, - status, - headers, - body_b64: base64_encode(&bytes), - final_url, - }; - json_to_c_string(&env) -} - -fn error_response(msg: &str) -> *mut c_char { - let env = ErrorEnvelope { ok: false, error: msg }; - json_to_c_string(&env) -} - -fn json_to_c_string(value: &T) -> *mut c_char { - let s = match serde_json::to_string(value) { - Ok(s) => s, - Err(_) => String::from("{\"ok\":false,\"error\":\"json serialize failed\"}"), - }; - match CString::new(s) { - Ok(c) => c.into_raw(), - Err(_) => ptr::null_mut(), - } -} - - -const B64_ALPHABET: &[u8; 64] = - b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - -fn base64_encode(bytes: &[u8]) -> String { - let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4); - let mut i = 0; - while i + 3 <= bytes.len() { - let n = ((bytes[i] as u32) << 16) - | ((bytes[i + 1] as u32) << 8) - | (bytes[i + 2] as u32); - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as 
usize] as char); - out.push(B64_ALPHABET[((n >> 6) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[(n & 0x3F) as usize] as char); - i += 3; - } - let rem = bytes.len() - i; - if rem == 1 { - let n = (bytes[i] as u32) << 16; - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as usize] as char); - out.push('='); - out.push('='); - } else if rem == 2 { - let n = ((bytes[i] as u32) << 16) | ((bytes[i + 1] as u32) << 8); - out.push(B64_ALPHABET[((n >> 18) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 12) & 0x3F) as usize] as char); - out.push(B64_ALPHABET[((n >> 6) & 0x3F) as usize] as char); - out.push('='); - } - out -} - -fn base64_decode(input: &str) -> Result, &'static str> { - let mut buf = Vec::with_capacity(input.len() * 3 / 4); - let mut bits: u32 = 0; - let mut nbits: u32 = 0; - for c in input.bytes() { - let v: u32 = match c { - b'A'..=b'Z' => (c - b'A') as u32, - b'a'..=b'z' => (c - b'a') as u32 + 26, - b'0'..=b'9' => (c - b'0') as u32 + 52, - b'+' | b'-' => 62, - b'/' | b'_' => 63, - b'=' | b'\n' | b'\r' | b' ' | b'\t' => continue, - _ => return Err("invalid base64 char"), - }; - bits = (bits << 6) | v; - nbits += 6; - if nbits >= 8 { - nbits -= 8; - buf.push(((bits >> nbits) & 0xFF) as u8); - } - } - Ok(buf) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn b64_roundtrip() { - for case in [&b""[..], b"f", b"fo", b"foo", b"foob", b"fooba", b"foobar"] { - let enc = base64_encode(case); - let dec = base64_decode(&enc).unwrap(); - assert_eq!(dec, case); - } - } -}