diff --git a/docs/gateway-sse-proxy-plan.md b/docs/gateway-sse-proxy-plan.md
new file mode 100644
index 0000000..912afe9
--- /dev/null
+++ b/docs/gateway-sse-proxy-plan.md
@@ -0,0 +1,79 @@
+# Gateway 层新增 `/chat/stream` 与 `/chat/audio/stream` SSE 转发计划
+
+## 背景
+- 现状:前端只通过 WebSocket 与 `output-handler` 通信,最终收到的是整段文本;`dialog-engine` 已提供 `POST /chat/stream` 与 `POST /chat/audio/stream` 的 SSE 流式接口,但未经 Gateway 暴露。
+- 目标:在 Gateway(FastAPI)统一暴露文本与音频 SSE 端点,集中处理 CORS、鉴权、速率限制及日志监控,使前端只需访问 Gateway。
+
+## 范围
+1. Gateway 新增 `POST /chat/stream` 与 `POST /chat/audio/stream`。
+2. 请求透传至 `dialog-engine` 对应 SSE 接口,响应体保持原始 SSE 事件顺序与格式。
+3. 统一处理 CORS、鉴权、错误响应与观测性。
+4. 提供最少的前端调整指导(URL、Header)。
+5. 不在本阶段处理输出层 WebSocket 协议变更。
+
+## 设计要点
+
+### 1. HTTP/SSE 代理实现
+- **请求封装**:复用现有 httpx AsyncClient;`timeout` 设为较大值或 `None`,避免流式超时。
+- **流式转发**:使用 `httpx.AsyncClient.stream()` 获取响应 `aiter_raw()`,通过 `StreamingResponse` 将字节块原样写回客户端;保持 `text/event-stream`、`cache-control: no-cache`、`connection: keep-alive` 等头部。
+- **头部与重写**:保留必要的 `Authorization`/`sessionId`;可根据需要新增 `X-Forwarded-*`。
+- **错误处理**:上游返回非 2xx 时,终止流并返回 JSON 错误;请求阶段异常(连接失败、超时)转换为 5xx 并记录。
+- **可取消性**:客户端断开时及时关闭 httpx 流,避免资源泄漏。
+
+### 2. CORS 与鉴权
+- **CORS**:复用 FastAPI CORS 中间件配置;确认是否需要暴露 SSE 特定头(如 `content-type`)。
+- **鉴权策略**:
+  - 支持现有 JWT/Token 验证(若尚未实现,至少保留钩子供后续扩展)。
+  - 将用户身份或 session 信息透传至 `dialog-engine`,用于上下文关联。
+- **速率限制/配额**:预留集成点,例如基于 Redis 的速率限制器;本阶段可先记录需求。
+
+### 3. 日志与监控
+- **日志**:记录请求入口、sessionId、目标 URL、状态码、耗时;对异常进行结构化日志。
+- **指标**:埋点统计活跃 SSE 连接数、平均持续时间、错误率;可写入 Prometheus exporter(若无可先准备接口)。
+- **Tracing**:如启用 OpenTelemetry,确保上下文跨服务传递。
+
+### 4. 配置与部署
+- 新增环境变量:`DIALOG_ENGINE_URL`(默认 `http://dialog-engine:8100`;或沿用已有 URL 推导)。
+- 更新 docker-compose/Helm 等部署文件,确保 Gateway 能访问 `dialog-engine` HTTP 端口。
+- 文档更新 `.env.example` 说明新的可配置项。
+
+### 5. 测试策略
+- **单元测试**:对 Gateway 新增路由编写测试,验证:
+  - 请求体透传
+  - SSE 头部正确
+  - 上游错误转换(4xx/5xx)
+- **集成测试**:本地 compose 启动后,使用测试客户端命中 `/chat/stream`,确认 `text-delta`、`done` 事件完整。
+- **回归测试**:验证现有 `/api/asr`、WebSocket 代理功能不受影响。
+- **手动验证**:前端或 curl `-N` 命令连接 Gateway,观察流式输出;模拟网络中断确保资源释放。
+
+### 6. 推广步骤
+1. **开发阶段**:在 feature 分支实现并通过本地测试。
+2. **代码审查**:重点关注资源释放、超时、错误映射。
+3. **预发布环境验证**:与前端联调,确认字幕/音频流实际可用。
+4. **灰度发布**:按环境依次上线,监控连接数与错误率。
+5. **文档更新**:README / 接口文档补充新的 Gateway 端点说明。
+
+### 7. 前端配合
+- 将 SSE 连接目标改为 Gateway `/chat/stream` 或 `/chat/audio/stream`。
+- 如果依赖 `Authorization` 头,确保在请求中发送;匹配 Gateway 新增的鉴权策略。
+- 处理连接断开时的重试与 UI 反馈。
+
+## 风险与缓解
+- **SSE 与 ASGI 兼容性**:FastAPI + Uvicorn 需使用 `--http=httptools` 或 `--http=h11`;如遇服务器缓存或反向代理缓冲(buffering)导致事件延迟的问题,需配置 `proxy_buffering off`、`proxy_set_header Connection keep-alive` 等反向代理设置。
+- **长连接资源占用**:评估并调整 Gateway worker 数、连接池大小(httpx `limits`);必要时采用负载均衡分流。
+- **鉴权未就绪**:若短期内无法实现 Token 校验,至少在接口说明中标记风险,规划后续迭代。
+- **跨域设置不足**:若前端来源不止 3000 端口,需要同步更新 `allow_origins` 列表或改为通配策略。
+
+## 里程碑(预估)
+| 阶段 | 内容 | 负责人 | 预估耗时 |
+| --- | --- | --- | --- |
+| 设计确认 | 评审本方案、确认鉴权需求 | 后端、前端 | 0.5 天 |
+| 开发实现 | Gateway 新增端点 + 配置 | 后端 | 1.5 天 |
+| 测试联调 | 单元/集成/前端联调 | 后端 + 前端 | 1 天 |
+| 上线与监控 | 发布到生产、监控稳定性 | DevOps | 0.5 天 |
+
+## 产出
+- Gateway 新增 SSE 代理代码与测试。
+- 更新后的配置文件、Docker/Helm。
+- 接口文档及使用指南。
+- (可选)监控仪表盘或报警规则。
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 5c5eecf..76780ec 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -28,8 +28,8 @@ Werkzeug==2.3.7
 psutil==5.9.6
 
 # ===== 测试框架 =====
-pytest==8.0.0
-pytest-asyncio==0.23.0
+pytest==8.4.2
+pytest-asyncio==1.2.0
 pytest-mock==3.12.0
 
 # ===== 开发工具 =====
diff --git a/services/gateway-python/main.py b/services/gateway-python/main.py
index 79b1244..aba3aac 100644
--- a/services/gateway-python/main.py
+++ b/services/gateway-python/main.py
@@ -3,17 +3,17 @@
 import os
 from contextlib import asynccontextmanager
 from typing import Dict
+from urllib.parse import urlparse, urlunparse
 
-import websockets
 import httpx
-from urllib.parse import 
urlparse, urlunparse
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
-from src.services.asr_routes import bp_asr as _flask_bp  # type: ignore
+import uvicorn
+import websockets
+from fastapi import FastAPI, HTTPException, Request, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.wsgi import WSGIMiddleware
+from fastapi.responses import HTMLResponse, JSONResponse, Response, StreamingResponse
 from flask import Flask as _Flask  # shim for mounting Flask blueprint
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import HTMLResponse, JSONResponse
-import uvicorn
+from src.services.asr_routes import bp_asr as _flask_bp  # type: ignore
 
 # 配置日志
 logging.basicConfig(
@@ -28,6 +28,9 @@
     "output": os.getenv("OUTPUT_HANDLER_URL", "ws://localhost:8002")
 }
 
+DIALOG_ENGINE_URL = os.getenv("DIALOG_ENGINE_URL", "http://dialog-engine:8100")
+SSE_TIMEOUT = httpx.Timeout(60.0, connect=5.0, read=None, write=10.0)
+
 # 活跃连接跟踪
 active_connections: Dict[str, WebSocket] = {}
 
@@ -175,6 +178,142 @@ def _output_http_base() -> str:
     return urlunparse(http_parsed).rstrip("/")
 
 
+def _dialog_engine_http_base() -> str:
+    """Return the dialog-engine base URL without a trailing slash."""
+    return DIALOG_ENGINE_URL.rstrip("/")
+
+
+# Headers that are not copied from the client request. "host" and
+# "content-length" are left for httpx to derive for the upstream request; the
+# rest are the hop-by-hop headers of RFC 9110 section 7.6.1, which describe a
+# single connection and must not be relayed by an intermediary.
+_NON_FORWARDED_HEADERS = frozenset(
+    {
+        "host",
+        "content-length",
+        "connection",
+        "keep-alive",
+        "proxy-authenticate",
+        "proxy-authorization",
+        "te",
+        "trailer",
+        "transfer-encoding",
+        "upgrade",
+    }
+)
+
+
+def _build_forward_headers(request: Request) -> Dict[str, str]:
+    """Build the headers sent to dialog-engine for a proxied request.
+
+    Client headers are copied except those in ``_NON_FORWARDED_HEADERS``.
+    ``accept`` and ``content-type`` default to SSE and JSON when absent.
+    ``accept-encoding`` is forced to ``identity`` because the upstream body is
+    relayed as raw, undecoded bytes.
+    """
+    headers: Dict[str, str] = {
+        key: value
+        for key, value in request.headers.items()
+        if key.lower() not in _NON_FORWARDED_HEADERS
+    }
+    headers.setdefault("accept", "text/event-stream")
+    headers.setdefault("content-type", "application/json")
+    headers["accept-encoding"] = "identity"
+    return headers
+
+
+async def _proxy_dialog_engine_stream(request: Request, path: str) -> Response:
+    """Proxy a POST request to a dialog-engine SSE endpoint.
+
+    Args:
+        request: Incoming gateway request whose body and headers are forwarded.
+        path: Path appended to the dialog-engine base URL.
+
+    Returns:
+        A ``StreamingResponse`` relaying the upstream bytes unchanged, or a
+        ``JSONResponse`` with the upstream status code when it is >= 400.
+
+    Raises:
+        HTTPException: 502 when the request to dialog-engine fails.
+    """
+    target_url = f"{_dialog_engine_http_base()}{path}"
+    # An empty request body is forwarded as an empty JSON object.
+    body = await request.body() or b"{}"
+    headers = _build_forward_headers(request)
+
+    client = httpx.AsyncClient(timeout=SSE_TIMEOUT, follow_redirects=False)
+    try:
+        upstream_response = await client.send(
+            client.build_request("POST", target_url, content=body, headers=headers),
+            stream=True,
+        )
+    except httpx.RequestError as exc:
+        await client.aclose()
+        raise HTTPException(status_code=502, detail=f"dialog-engine unreachable: {exc}") from exc
+    except BaseException:
+        # Any other failure (invalid URL, cancellation, ...) must not leak the client.
+        await client.aclose()
+        raise
+
+    async def close_upstream() -> None:
+        # Close the client even when closing the response itself fails.
+        try:
+            await upstream_response.aclose()
+        finally:
+            await client.aclose()
+
+    if upstream_response.status_code >= 400:
+        try:
+            detail_bytes = await upstream_response.aread()
+        except httpx.HTTPError:
+            # The error body could not be read; fall back to the reason phrase.
+            detail_bytes = b""
+        finally:
+            await close_upstream()
+        detail = detail_bytes.decode("utf-8", errors="replace") or upstream_response.reason_phrase
+        return JSONResponse(
+            {"error": "dialog_engine_error", "detail": detail},
+            status_code=upstream_response.status_code,
+        )
+
+    async def event_stream():
+        try:
+            async for chunk in upstream_response.aiter_raw():
+                if chunk:
+                    yield chunk
+        finally:
+            # Release the upstream response and client when iteration stops.
+            await close_upstream()
+
+    # "content-encoding" is passed through so the client can still decode the
+    # raw bytes if the upstream compresses despite "accept-encoding: identity".
+    response_headers = {
+        name: upstream_response.headers[name]
+        for name in ("cache-control", "content-language", "content-encoding")
+        if name in upstream_response.headers
+    }
+
+    media_type = upstream_response.headers.get("content-type", "text/event-stream")
+    return StreamingResponse(
+        event_stream(),
+        status_code=upstream_response.status_code,
+        media_type=media_type,
+        headers=response_headers,
+    )
+
+
+@app.post("/chat/stream")
+async def proxy_chat_stream(request: Request):
+    """Proxy SSE chat stream to dialog-engine."""
+    return await _proxy_dialog_engine_stream(request, "/chat/stream")
+
+
+@app.post("/chat/audio/stream")
+async def proxy_chat_audio_stream(request: Request):
+    """Proxy SSE audio stream to dialog-engine."""
+    return await _proxy_dialog_engine_stream(request, "/chat/audio/stream")
+
+
 @app.post("/control/stop")
 async def control_stop_proxy(payload: Dict[str, str]):
     """Proxy STOP control to output-handler's /control/stop.
diff --git a/services/gateway-python/tests/unit/test_asr_route.py b/services/gateway-python/tests/unit/test_asr_route.py
index 4c54573..62c1c2e 100644
--- a/services/gateway-python/tests/unit/test_asr_route.py
+++ b/services/gateway-python/tests/unit/test_asr_route.py
@@ -1,8 +1,6 @@
 # 注意:测试在模块目录下运行:cd services/gateway-python && pytest
+import base64
 import importlib
-import json
-import os
-import types
 
 import pytest
 from fastapi.testclient import TestClient
@@ -26,35 +24,41 @@ def test_asr_route_requires_absolute_path(client, monkeypatch):
     resp = client.post("/api/asr", json={"path": "relative.wav"})
     assert resp.status_code == 400
     data = resp.json()
-    assert "path must be absolute" in data.get("error", "")
+    assert data.get("error") == "path_must_be_absolute"
 
 
-def test_asr_route_pushes_to_redis_list(client, monkeypatch):
-    # 准备一个假的 Redis 客户端来捕获 lpush 调用
-    pushed = []
+def test_asr_route_reads_file_and_invokes_dialog_engine(client, monkeypatch, tmp_path):
+    asr_routes = importlib.import_module("src.services.asr_routes")
 
-    class DummyRedis:
-        async def lpush(self, queue, message):
-            pushed.append((queue, message))
+    audio_path = tmp_path / "sample.wav"
+    audio_path.write_bytes(b"audio-bytes")
 
-    # 替换 asr_routes.get_redis 返回 DummyRedis
-    asr_routes = importlib.import_module("src.services.asr_routes")
-    monkeypatch.setattr(asr_routes, "get_redis", lambda: DummyRedis())
+    captured_payload = {}
+
+    async def fake_invoke(payload):
+        captured_payload.update(payload)
+        return {"sessionId": payload["sessionId"], "reply": "hi", "transcript": "test"}
+
+    monkeypatch.setattr(asr_routes, "_invoke_dialog_engine", fake_invoke)
+
+    resp = client.post(
+        "/api/asr",
+        json={
+            "path": str(audio_path),
+            "sessionId": "sess-1",
+            "contentType": "audio/wav",
+            "options": {"lang": "zh"},
+        },
+    )
 
-    abs_path = "/tmp/file.wav"
-    resp = client.post("/api/asr", json={"path": abs_path, "options": {"lang": "zh"}})
     assert resp.status_code == 200
     data = resp.json()
-    assert "task_id" in data
-
-    # 验证写入了正确的队列与消息格式
-    assert len(pushed) == 1
-    queue, message = pushed[0]
-    assert queue == os.environ.get("ASR_TASKS_QUEUE", "asr_tasks")
-
-    msg = json.loads(message)
-    assert msg["audio"]["type"] == "file"
-    assert msg["audio"]["path"] == abs_path
-    assert msg["audio"]["format"] == "wav"
-    assert msg["options"]["lang"] == "zh"
-    assert msg["meta"]["source"] == "gateway"
+    assert data["reply"] == "hi"
+    assert data["sessionId"] == "sess-1"
+
+    assert captured_payload["sessionId"] == "sess-1"
+    assert captured_payload["contentType"] == "audio/wav"
+    assert captured_payload["lang"] == "zh"
+
+    decoded_audio = base64.b64decode(captured_payload["audio"])
+    assert decoded_audio == b"audio-bytes"
diff --git a/services/gateway-python/tests/unit/test_sse_routes.py b/services/gateway-python/tests/unit/test_sse_routes.py
new file mode 100644
index 0000000..24d27ea
--- /dev/null
+++ b/services/gateway-python/tests/unit/test_sse_routes.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Callable, Dict, List
+
+import httpx
+import pytest
+from fastapi.testclient import TestClient
+
+import main
+
+
+class FakeResponse:
+    """Minimal stand-in for the streamed upstream response used by the proxy."""
+
+    def __init__(
+        self,
+        *,
+        status_code: int = 200,
+        headers: Dict[str, str] | None = None,
+        chunks: List[bytes] | None = None,
+        detail_text: str | None = None,
+    ) -> None:
+        self.status_code = status_code
+        self.headers = headers or {"content-type": "text/event-stream"}
+        # detail_text only becomes the body when no chunks were supplied.
+        self._chunks = list(chunks or [])
+        if detail_text and not self._chunks:
+            self._chunks.append(detail_text.encode("utf-8"))
+        self.reason_phrase = "OK"
+        self.closed = False
+
+    async def aiter_raw(self):
+        for chunk in self._chunks:
+            yield chunk
+
+    async def aread(self) -> bytes:
+        return b"".join(self._chunks)
+
+    async def aclose(self) -> None:
+        self.closed = True
+
+
+class DummyAsyncClient:
+    """Fake ``httpx.AsyncClient`` that records the request and its own closing."""
+
+    def __init__(self, response_factory: Callable[[httpx.Request], FakeResponse]) -> None:
+        self._response_factory = response_factory
+        self.last_request: httpx.Request | None = None
+        self.closed = False
+        self._response: FakeResponse | None = None
+
+    def build_request(self, method: str, url: str, *, headers: Dict[str, str], content: bytes) -> httpx.Request:
+        self.last_request = httpx.Request(method, url, headers=headers, content=content)
+        return self.last_request
+
+    async def send(self, request: httpx.Request, stream: bool = False) -> FakeResponse:
+        self.last_request = request
+        self._response = self._response_factory(request)
+        return self._response
+
+    async def aclose(self) -> None:
+        self.closed = True
+
+
+@pytest.fixture
+def client():
+    with TestClient(main.app) as test_client:
+        yield test_client
+
+
+def _install_fake_client(
+    monkeypatch, factory: Callable[[httpx.Request], FakeResponse]
+) -> List[DummyAsyncClient]:
+    """Patch ``httpx.AsyncClient`` to hand out ``DummyAsyncClient`` instances.
+
+    Returns the list that collects every instance created during the test.
+    """
+    created_clients: List[DummyAsyncClient] = []
+
+    def fake_async_client(*args, **kwargs):
+        instance = DummyAsyncClient(factory)
+        created_clients.append(instance)
+        return instance
+
+    monkeypatch.setattr(main.httpx, "AsyncClient", fake_async_client)
+    return created_clients
+
+
+@pytest.mark.parametrize(
+    ("endpoint", "expected_path"),
+    [
+        ("/chat/stream", "/chat/stream"),
+        ("/chat/audio/stream", "/chat/audio/stream"),
+    ],
+)
+def test_sse_proxy_success(monkeypatch, client: TestClient, endpoint: str, expected_path: str) -> None:
+    chunks = [b"event: text-delta\n", b"data: hello\n\n"]
+    captured: Dict[str, Any] = {}
+
+    def factory(request: httpx.Request) -> FakeResponse:
+        captured["url"] = str(request.url)
+        captured["headers"] = dict(request.headers)
+        captured["content"] = request.content
+        return FakeResponse(
+            status_code=200,
+            headers={"content-type": "text/event-stream", "cache-control": "no-cache"},
+            chunks=chunks,
+        )
+
+    created_clients = _install_fake_client(monkeypatch, factory)
+
+    response = client.post(endpoint, json={"sessionId": "sess"}, headers={"X-Test": "1"})
+
+    assert response.status_code == 200
+    assert response.headers["content-type"].startswith("text/event-stream")
+    assert response.content == b"".join(chunks)
+
+    assert captured["url"] == f"{main.DIALOG_ENGINE_URL.rstrip('/')}{expected_path}"
+    assert captured["headers"]["x-test"] == "1"
+    assert captured["headers"]["content-type"] == "application/json"
+    parsed_body = json.loads(captured["content"].decode("utf-8"))
+    assert parsed_body == {"sessionId": "sess"}
+
+    client_instance = created_clients[0]
+    assert client_instance.closed is True
+    assert client_instance._response is not None
+    assert client_instance._response.closed is True
+
+
+def test_sse_proxy_error(monkeypatch, client: TestClient) -> None:
+    def factory(request: httpx.Request) -> FakeResponse:
+        return FakeResponse(status_code=503, headers={"content-type": "text/event-stream"}, detail_text="fail")
+
+    created_clients = _install_fake_client(monkeypatch, factory)
+
+    response = client.post("/chat/stream", json={"sessionId": "s"})
+
+    assert response.status_code == 503
+    assert response.json() == {"error": "dialog_engine_error", "detail": "fail"}
+
+    client_instance = created_clients[0]
+    assert client_instance.closed is True
+    assert client_instance._response is not None
+    assert client_instance._response.closed is True
+
+
+def test_sse_proxy_upstream_unreachable(monkeypatch, client: TestClient) -> None:
+    def factory(request: httpx.Request) -> FakeResponse:
+        raise httpx.ConnectError("boom", request=request)
+
+    created_clients = _install_fake_client(monkeypatch, factory)
+
+    response = client.post("/chat/stream", json={"sessionId": "s"})
+
+    assert response.status_code == 502
+    assert "dialog-engine unreachable" in response.json()["detail"]
+    assert created_clients[0].closed is True