diff --git a/.env.example b/.env.example index 0c1b95d..5e14c20 100644 --- a/.env.example +++ b/.env.example @@ -56,6 +56,14 @@ ASR_WHISPER_DEVICE=auto ASR_WHISPER_COMPUTE_TYPE=int8 ASR_WHISPER_BEAM_SIZE=1 ASR_WHISPER_CACHE_DIR= +# Volcengine streaming ASR (必填凭证;启用火山引擎 ASR 时需设置) +ASR_VOLC_ENDPOINT=wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async +ASR_VOLC_APP_KEY= +ASR_VOLC_ACCESS_KEY= +ASR_VOLC_RESOURCE_ID= +ASR_VOLC_CONNECT_ID_PREFIX=dialog-engine +ASR_VOLC_TIMEOUT_SECONDS=15 +ASR_VOLC_FAILOVER_THRESHOLD=5 # 可选:自定义 OpenAI 终端或组织 OPENAI_BASE_URL= diff --git a/demo/sauc_python/readme.md b/demo/sauc_python/readme.md new file mode 100644 index 0000000..4dbcebd --- /dev/null +++ b/demo/sauc_python/readme.md @@ -0,0 +1,15 @@ +# README + +**asr tob 相关client demo** + +# Notice +python version: python 3.x + +替换代码中的key为真实数据: + "app_key": "xxxxxxx", + "access_key": "xxxxxxxxxxxxxxxx" +使用示例: + python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav + + + diff --git a/demo/sauc_python/sauc_websocket_demo.py b/demo/sauc_python/sauc_websocket_demo.py new file mode 100644 index 0000000..842fb54 --- /dev/null +++ b/demo/sauc_python/sauc_websocket_demo.py @@ -0,0 +1,556 @@ +import asyncio +import aiohttp +import json +import struct +import gzip +import uuid +import logging +import os +import subprocess +import io +from typing import Optional, List, Dict, Any, Tuple, AsyncGenerator + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler('run.log'), + logging.StreamHandler() + ] +) +logger = logging.getLogger(__name__) + +# 常量定义 +DEFAULT_SAMPLE_RATE = 16000 + +class ProtocolVersion: + V1 = 0b0001 + +class MessageType: + CLIENT_FULL_REQUEST = 0b0001 + CLIENT_AUDIO_ONLY_REQUEST = 0b0010 + SERVER_FULL_RESPONSE = 0b1001 + SERVER_ERROR_RESPONSE = 0b1111 + +class MessageTypeSpecificFlags: + NO_SEQUENCE = 0b0000 + POS_SEQUENCE = 0b0001 + NEG_SEQUENCE = 0b0010 + 
NEG_WITH_SEQUENCE = 0b0011 + +class SerializationType: + NO_SERIALIZATION = 0b0000 + JSON = 0b0001 + +class CompressionType: + GZIP = 0b0001 + + +class Config: + def __init__(self): + # 从环境变量读取,找不到再退回占位符 + app_key = os.getenv("ASR_VOLC_APP_KEY", "xxxxxxx") + access_key = os.getenv("ASR_VOLC_ACCESS_KEY", "xxxxxxxxxxxx") + resource_id = os.getenv("ASR_VOLC_RESOURCE_ID", "volc.bigasr.sauc.duration") + self.auth = { + "app_key": app_key, + "access_key": access_key, + "resource_id": resource_id, + } + + @property + def app_key(self) -> str: + return self.auth["app_key"] + + @property + def access_key(self) -> str: + return self.auth["access_key"] + + @property + def resource_id(self) -> str: + return self.auth["resource_id"] + +config = Config() + +class CommonUtils: + @staticmethod + def gzip_compress(data: bytes) -> bytes: + return gzip.compress(data) + + @staticmethod + def gzip_decompress(data: bytes) -> bytes: + return gzip.decompress(data) + + @staticmethod + def judge_wav(data: bytes) -> bool: + if len(data) < 44: + return False + return data[:4] == b'RIFF' and data[8:12] == b'WAVE' + + @staticmethod + def convert_wav_with_path(audio_path: str, sample_rate: int = DEFAULT_SAMPLE_RATE) -> bytes: + try: + cmd = [ + "ffmpeg", "-v", "quiet", "-y", "-i", audio_path, + "-acodec", "pcm_s16le", "-ac", "1", "-ar", str(sample_rate), + "-f", "wav", "-" + ] + result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return result.stdout + except FileNotFoundError: + # 无 ffmpeg,尝试 PyAV 方案 + try: + import av + from av.audio.resampler import AudioResampler + import wave + + container = av.open(audio_path, mode='r') + audio_stream = next((s for s in container.streams if s.type == 'audio'), None) + if audio_stream is None: + raise RuntimeError('No audio stream found') + + resampler = AudioResampler(format='s16', layout='mono', rate=sample_rate) + pcm_bytes = bytearray() + for packet in container.demux(audio_stream): + for frame in packet.decode(): 
+ for rframe in resampler.resample(frame): + arr = rframe.to_ndarray() + pcm_bytes.extend(arr.tobytes(order='C')) + + # 写成 WAV 到内存 + buf = io.BytesIO() + with wave.open(buf, 'wb') as w: + w.setnchannels(1) + w.setsampwidth(2) + w.setframerate(sample_rate) + w.writeframes(bytes(pcm_bytes)) + return buf.getvalue() + except Exception as e: + logger.error(f"PyAV conversion failed: {e}") + raise RuntimeError(f"Audio conversion failed via PyAV: {e}") + except subprocess.CalledProcessError as e: + logger.error(f"FFmpeg conversion failed: {e.stderr.decode()}") + raise RuntimeError(f"Audio conversion failed: {e.stderr.decode()}") + + @staticmethod + def read_wav_info(data: bytes) -> Tuple[int, int, int, int, bytes]: + if len(data) < 44: + raise ValueError("Invalid WAV file: too short") + + # 解析WAV头 + chunk_id = data[:4] + if chunk_id != b'RIFF': + raise ValueError("Invalid WAV file: not RIFF format") + + format_ = data[8:12] + if format_ != b'WAVE': + raise ValueError("Invalid WAV file: not WAVE format") + + # 解析fmt子块 + audio_format = struct.unpack(' 'AsrRequestHeader': + self.message_type = message_type + return self + + def with_message_type_specific_flags(self, flags: int) -> 'AsrRequestHeader': + self.message_type_specific_flags = flags + return self + + def with_serialization_type(self, serialization_type: int) -> 'AsrRequestHeader': + self.serialization_type = serialization_type + return self + + def with_compression_type(self, compression_type: int) -> 'AsrRequestHeader': + self.compression_type = compression_type + return self + + def with_reserved_data(self, reserved_data: bytes) -> 'AsrRequestHeader': + self.reserved_data = reserved_data + return self + + def to_bytes(self) -> bytes: + header = bytearray() + header.append((ProtocolVersion.V1 << 4) | 1) + header.append((self.message_type << 4) | self.message_type_specific_flags) + header.append((self.serialization_type << 4) | self.compression_type) + header.extend(self.reserved_data) + return bytes(header) + + 
@staticmethod + def default_header() -> 'AsrRequestHeader': + return AsrRequestHeader() + +class RequestBuilder: + @staticmethod + def new_auth_headers() -> Dict[str, str]: + reqid = str(uuid.uuid4()) + return { + "X-Api-Resource-Id": config.resource_id, + "X-Api-Request-Id": reqid, + "X-Api-Access-Key": config.access_key, + "X-Api-App-Key": config.app_key + } + + @staticmethod + def new_full_client_request(seq: int) -> bytes: # 添加seq参数 + header = AsrRequestHeader.default_header() \ + .with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) + + payload = { + "user": { + "uid": "demo_uid" + }, + "audio": { + "format": "wav", + "codec": "raw", + "rate": 16000, + "bits": 16, + "channel": 1 + }, + "request": { + "model_name": "bigmodel", + "enable_itn": True, + "enable_punc": True, + "enable_ddc": True, + "show_utterances": True, + "enable_nonstream": False + } + } + + payload_bytes = json.dumps(payload).encode('utf-8') + compressed_payload = CommonUtils.gzip_compress(payload_bytes) + payload_size = len(compressed_payload) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack('>i', seq)) # 使用传入的seq + request.extend(struct.pack('>I', payload_size)) + request.extend(compressed_payload) + + return bytes(request) + + @staticmethod + def new_audio_only_request(seq: int, segment: bytes, is_last: bool = False) -> bytes: + header = AsrRequestHeader.default_header() + if is_last: # 最后一个包特殊处理 + header.with_message_type_specific_flags(MessageTypeSpecificFlags.NEG_WITH_SEQUENCE) + seq = -seq # 设为负值 + else: + header.with_message_type_specific_flags(MessageTypeSpecificFlags.POS_SEQUENCE) + header.with_message_type(MessageType.CLIENT_AUDIO_ONLY_REQUEST) + + request = bytearray() + request.extend(header.to_bytes()) + request.extend(struct.pack('>i', seq)) + + compressed_segment = CommonUtils.gzip_compress(segment) + request.extend(struct.pack('>I', len(compressed_segment))) + request.extend(compressed_segment) + + return 
bytes(request) + +class AsrResponse: + def __init__(self): + self.code = 0 + self.event = 0 + self.is_last_package = False + self.payload_sequence = 0 + self.payload_size = 0 + self.payload_msg = None + + def to_dict(self) -> Dict[str, Any]: + return { + "code": self.code, + "event": self.event, + "is_last_package": self.is_last_package, + "payload_sequence": self.payload_sequence, + "payload_size": self.payload_size, + "payload_msg": self.payload_msg + } + +class ResponseParser: + @staticmethod + def parse_response(msg: bytes) -> AsrResponse: + response = AsrResponse() + + header_size = msg[0] & 0x0f + message_type = msg[1] >> 4 + message_type_specific_flags = msg[1] & 0x0f + serialization_method = msg[2] >> 4 + message_compression = msg[2] & 0x0f + + payload = msg[header_size*4:] + + # 解析message_type_specific_flags + if message_type_specific_flags & 0x01: + response.payload_sequence = struct.unpack('>i', payload[:4])[0] + payload = payload[4:] + if message_type_specific_flags & 0x02: + response.is_last_package = True + if message_type_specific_flags & 0x04: + response.event = struct.unpack('>i', payload[:4])[0] + payload = payload[4:] + + # 解析message_type + if message_type == MessageType.SERVER_FULL_RESPONSE: + response.payload_size = struct.unpack('>I', payload[:4])[0] + payload = payload[4:] + elif message_type == MessageType.SERVER_ERROR_RESPONSE: + response.code = struct.unpack('>i', payload[:4])[0] + response.payload_size = struct.unpack('>I', payload[4:8])[0] + payload = payload[8:] + + if not payload: + return response + + # 解压缩 + if message_compression == CompressionType.GZIP: + try: + payload = CommonUtils.gzip_decompress(payload) + except Exception as e: + logger.error(f"Failed to decompress payload: {e}") + return response + + # 解析payload + try: + if serialization_method == SerializationType.JSON: + response.payload_msg = json.loads(payload.decode('utf-8')) + except Exception as e: + logger.error(f"Failed to parse payload: {e}") + + return response + 
class AsrWsClient:
    """Stream one audio file to a Volcengine SAUC WebSocket ASR endpoint.

    Typical use::

        async with AsrWsClient(url) as client:
            async for response in client.execute(path):
                ...

    The client may also be used without ``async with``; in that case
    ``create_connection`` opens an HTTP session on demand and ``execute``
    closes it again when the stream finishes.
    """

    def __init__(self, url: str, segment_duration: int = 200):
        """Store connection settings; no network I/O happens here.

        Args:
            url: WebSocket endpoint to connect to.
            segment_duration: Milliseconds of audio carried by each frame.
        """
        self.seq = 1  # sequence number of the next outgoing frame
        self.url = url
        self.segment_duration = segment_duration
        self.conn = None
        self.session = None
        # True only while the session was opened by create_connection()
        # rather than by __aenter__; execute() uses it to decide whether it
        # is responsible for closing the session.
        self._adhoc_session = False

    async def __aenter__(self):
        """Open the aiohttp session owned by the ``async with`` block."""
        self.session = aiohttp.ClientSession()
        self._adhoc_session = False
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the WebSocket and the session if they are still open."""
        if self.conn and not self.conn.closed:
            await self.conn.close()
        if self.session and not self.session.closed:
            await self.session.close()

    async def read_audio_data(self, file_path: str) -> bytes:
        """Return the file's bytes as WAV, converting when it is not WAV.

        Raises:
            Exception: Any read or conversion error is logged and re-raised.
        """
        try:
            with open(file_path, 'rb') as f:
                content = f.read()

            if not CommonUtils.judge_wav(content):
                logger.info("Converting audio to WAV format...")
                content = CommonUtils.convert_wav_with_path(file_path, DEFAULT_SAMPLE_RATE)

            return content
        except Exception as e:
            logger.error(f"Failed to read audio data: {e}")
            raise

    def get_segment_size(self, content: bytes) -> int:
        """Return the number of bytes that span ``segment_duration`` ms.

        The size is derived from the channel count, sample width and frame
        rate reported by ``CommonUtils.read_wav_info``.

        Raises:
            Exception: Any parsing error is logged and re-raised.
        """
        try:
            channel_num, samp_width, frame_rate, _, _ = CommonUtils.read_wav_info(content)[:5]
            size_per_sec = channel_num * samp_width * frame_rate
            return size_per_sec * self.segment_duration // 1000
        except Exception as e:
            logger.error(f"Failed to calculate segment size: {e}")
            raise

    async def create_connection(self) -> None:
        """Open the WebSocket connection with freshly built auth headers.

        When the client was not entered with ``async with`` (or its session
        has already been closed) a session is created here instead of failing
        on ``None``. ``execute`` closes such a session when it finishes; a
        caller invoking this method directly must close ``self.session``
        itself.

        Raises:
            Exception: Any connection error is logged and re-raised.
        """
        headers = RequestBuilder.new_auth_headers()
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession()
            self._adhoc_session = True
        try:
            self.conn = await self.session.ws_connect(
                self.url,
                headers=headers
            )
            logger.info(f"Connected to {self.url}")
        except Exception as e:
            logger.error(f"Failed to connect to WebSocket: {e}")
            raise

    async def send_full_client_request(self) -> None:
        """Send the initial configuration frame and log the first reply.

        Raises:
            Exception: Any send/receive error is logged and re-raised.
        """
        request = RequestBuilder.new_full_client_request(self.seq)
        self.seq += 1  # advance once the frame has been built
        try:
            await self.conn.send_bytes(request)
            logger.info(f"Sent full client request with seq: {self.seq-1}")

            msg = await self.conn.receive()
            if msg.type == aiohttp.WSMsgType.BINARY:
                response = ResponseParser.parse_response(msg.data)
                logger.info(f"Received response: {response.to_dict()}")
            else:
                logger.error(f"Unexpected message type: {msg.type}")
        except Exception as e:
            logger.error(f"Failed to send full client request: {e}")
            raise

    async def send_messages(self, segment_size: int, content: bytes) -> AsyncGenerator[None, None]:
        """Send the audio in ``segment_size``-byte frames, one per interval.

        Yields once after every frame so the consumer can interleave other
        work. The last frame is flagged via ``is_last=True``.
        """
        audio_segments = self.split_audio(content, segment_size)
        total_segments = len(audio_segments)

        for i, segment in enumerate(audio_segments):
            is_last = (i == total_segments - 1)
            request = RequestBuilder.new_audio_only_request(
                self.seq,
                segment,
                is_last=is_last
            )
            await self.conn.send_bytes(request)
            logger.info(f"Sent audio segment with seq: {self.seq} (last: {is_last})")

            if not is_last:
                self.seq += 1

            # Pace the frames to imitate a real-time stream.
            await asyncio.sleep(self.segment_duration / 1000)
            # Hand control back so responses can be received.
            yield

    async def recv_messages(self) -> AsyncGenerator["AsrResponse", None]:
        """Yield parsed responses until the final package, an error code,
        a WebSocket error or a close message is seen.

        Raises:
            Exception: Any receive error is logged and re-raised.
        """
        try:
            async for msg in self.conn:
                if msg.type == aiohttp.WSMsgType.BINARY:
                    response = ResponseParser.parse_response(msg.data)
                    yield response

                    if response.is_last_package or response.code != 0:
                        break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {msg.data}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    logger.info("WebSocket connection closed")
                    break
        except Exception as e:
            logger.error(f"Error receiving messages: {e}")
            raise

    async def start_audio_stream(self, segment_size: int, content: bytes) -> AsyncGenerator["AsrResponse", None]:
        """Send audio in a background task while yielding responses.

        The sender task is cancelled and awaited once receiving stops, so it
        never outlives this generator.
        """
        async def sender():
            async for _ in self.send_messages(segment_size, content):
                pass

        # Sending runs concurrently with the receive loop below.
        sender_task = asyncio.create_task(sender())

        try:
            async for response in self.recv_messages():
                yield response
        finally:
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass

    @staticmethod
    def split_audio(data: bytes, segment_size: int) -> List[bytes]:
        """Split ``data`` into consecutive chunks of ``segment_size`` bytes.

        The last chunk may be shorter. A non-positive ``segment_size`` or
        empty ``data`` yields an empty list.
        """
        if segment_size <= 0:
            return []
        # Slicing clamps at the end of ``data``, so no manual bound is needed.
        return [data[i:i + segment_size] for i in range(0, len(data), segment_size)]

    async def execute(self, file_path: str) -> AsyncGenerator["AsrResponse", None]:
        """Run one full recognition pass over ``file_path``.

        Reads (and if needed converts) the audio, connects, sends the
        configuration frame, then streams the audio while yielding every
        response. The connection is always closed afterwards.

        Raises:
            ValueError: If ``file_path`` or the URL is empty, or if the
                computed segment size is not positive (no frame, and thus no
                terminating frame, would ever be sent).
            Exception: Any other failure is logged and re-raised.
        """
        if not file_path:
            raise ValueError("File path is empty")

        if not self.url:
            raise ValueError("URL is empty")

        self.seq = 1

        try:
            # 1. Read the audio file.
            content = await self.read_audio_data(file_path)

            # 2. Work out how many bytes go into each frame.
            segment_size = self.get_segment_size(content)
            if segment_size <= 0:
                raise ValueError(
                    f"Segment size must be positive, got {segment_size}; "
                    "check the segment duration and the WAV header"
                )

            # 3. Open the WebSocket connection.
            await self.create_connection()

            # 4. Send the full client (configuration) request.
            await self.send_full_client_request()

            # 5. Stream the audio and relay the responses.
            async for response in self.start_audio_stream(segment_size, content):
                yield response

        except Exception as e:
            logger.error(f"Error in ASR execution: {e}")
            raise
        finally:
            if self.conn:
                await self.conn.close()
            # A session opened on demand by create_connection() has no
            # ``async with`` block to close it, so release it here.
            if self._adhoc_session and self.session is not None:
                if not self.session.closed:
                    await self.session.close()
                self.session = None
                self._adhoc_session = False


async def main():
    """Command-line entry point: recognise one file and log every response."""
    import argparse

    parser = argparse.ArgumentParser(description="ASR WebSocket Client")
    parser.add_argument("--file", type=str, required=True, help="Audio file path")

    # Known endpoints:
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async
    #   wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_nostream
    parser.add_argument("--url", type=str, default=os.getenv("ASR_VOLC_ENDPOINT", "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"),
                        help="WebSocket URL")
    parser.add_argument("--seg-duration", type=int, default=200,
                        help="Audio duration(ms) per packet, default:200")

    args = parser.parse_args()

    async with AsrWsClient(args.url, args.seg_duration) as client:
        try:
            async for response in client.execute(args.file):
                logger.info(f"Received response: {json.dumps(response.to_dict(), indent=2, ensure_ascii=False)}")
        except Exception as e:
            logger.error(f"ASR processing failed: {e}")


if __name__ == "__main__":
    asyncio.run(main())

    # Usage:
    # python3 sauc_websocket_demo.py --file /Users/bytedance/code/python/eng_ddc_itn.wav
a/docker-compose.dev.yml b/docker-compose.dev.yml index fab630f..070654e 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -43,6 +43,7 @@ services: dockerfile: Dockerfile container_name: aivtuber-dialog-engine-dev environment: + - LOG_LEVEL=DEBUG - PORT=8100 - PYTHONPATH=/app/src - REDIS_HOST=${REDIS_HOST:-redis} @@ -85,6 +86,26 @@ services: - LLM_REQUEST_TIMEOUT=${LLM_REQUEST_TIMEOUT:-30} - LLM_RETRY_LIMIT=${LLM_RETRY_LIMIT:-2} - LLM_RETRY_BACKOFF_SECONDS=${LLM_RETRY_BACKOFF_SECONDS:-0.5} + # ASR configuration + - ASR_ENABLED=${ASR_ENABLED:-true} + - ASR_PROVIDER=${ASR_PROVIDER:-mock} + - ASR_MAX_BYTES=${ASR_MAX_BYTES:-5242880} + - ASR_MAX_DURATION_SECONDS=${ASR_MAX_DURATION_SECONDS:-300} + - ASR_TARGET_SAMPLE_RATE=${ASR_TARGET_SAMPLE_RATE:-16000} + - ASR_TARGET_CHANNELS=${ASR_TARGET_CHANNELS:-1} + - ASR_DEFAULT_LANG=${ASR_DEFAULT_LANG:-} + - ASR_WHISPER_MODEL=${ASR_WHISPER_MODEL:-base} + - ASR_WHISPER_DEVICE=${ASR_WHISPER_DEVICE:-auto} + - ASR_WHISPER_COMPUTE_TYPE=${ASR_WHISPER_COMPUTE_TYPE:-int8} + - ASR_WHISPER_BEAM_SIZE=${ASR_WHISPER_BEAM_SIZE:-1} + - ASR_WHISPER_CACHE_DIR=${ASR_WHISPER_CACHE_DIR:-} + - ASR_VOLC_ENDPOINT=${ASR_VOLC_ENDPOINT:-wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async} + - ASR_VOLC_APP_KEY=${ASR_VOLC_APP_KEY:-} + - ASR_VOLC_ACCESS_KEY=${ASR_VOLC_ACCESS_KEY:-} + - ASR_VOLC_RESOURCE_ID=${ASR_VOLC_RESOURCE_ID:-} + - ASR_VOLC_CONNECT_ID_PREFIX=${ASR_VOLC_CONNECT_ID_PREFIX:-dialog-engine} + - ASR_VOLC_TIMEOUT_SECONDS=${ASR_VOLC_TIMEOUT_SECONDS:-15} + - ASR_VOLC_FAILOVER_THRESHOLD=${ASR_VOLC_FAILOVER_THRESHOLD:-5} # Internal state management (AI emotions & affinity) - INTERNAL_STATE_DB_PATH=${INTERNAL_STATE_DB_PATH:-/app/data/internal_states.db} ports: diff --git a/docs/python-environment-guidelines.md b/docs/python-environment-guidelines.md new file mode 100644 index 0000000..ba83397 --- /dev/null +++ b/docs/python-environment-guidelines.md @@ -0,0 +1,36 @@ +# Python 环境与依赖指引 + +## 推荐 Python 版本 +- **Python 
3.11.x**:目前项目在本地与 CI/测试中均使用 Python 3.11.14,通过 `uv python install 3.11` 可以快速获取该版本。 +- 历史上部分开发机可能仍是系统 Python 3.9;建议迁移到 3.11,避免类型提示(PEP 604 等)和依赖冲突问题。 + +## 工具链建议 +- **uv**(https://github.com/astral-sh/uv)作为默认包管理与虚拟环境工具: + - 创建虚拟环境:`uv venv --python 3.11 .venv` + - 安装依赖(示例):`uv pip install --python .venv/bin/python -r services/gateway-python/requirements.txt` + - 查看已装包:`uv pip list --python .venv/bin/python` +- 如需传统 `pip`,可在 `.venv/bin/python -m pip install ...` 中使用,但 uv 已封装常用场景。 + +## 依赖分层 +- **根目录**: + - `requirements-dev.txt`:通用开发工具(lint、format、测试等)。 + - 运行全局脚本或多服务开发时,请在 `.venv` 中一次性安装:`uv pip install --python .venv/bin/python -r requirements-dev.txt` +- **各微服务**: + - `services/{service}/requirements.txt`:服务运行时依赖。 + - `services/{service}/requirements-test.txt`:该服务的测试/CI 依赖。 + - 示例(Gateway): + ```bash + uv pip install --python .venv/bin/python -r services/gateway-python/requirements.txt + uv pip install --python .venv/bin/python -r services/gateway-python/requirements-test.txt + ``` +- **前端**: + - `front_end/package.json` 与 `package-lock.json` 由 npm 管理;建议使用 `npm install`(或 `pnpm` 视团队约定而定)。 + +## 虚拟环境规范 +- 默认在仓库根目录创建 `.venv/`,便于脚本、IDE、CI 定位。 +- 激活方式:`source .venv/bin/activate`(macOS/Linux)或使用 `uv run`/`.venv/bin/python` 直接调用。 +- 不要将 `.venv` 提交到版本库;`.gitignore` 已默认忽略。 + +## 版本固定与后续工作 +- 各服务的 `requirements*.txt` 目前以浮动上限为主,建议逐步改为锁定范围(例如 `~=`, `<`)。 +- 后续可引入 `pyproject.toml` 与 uv 的依赖组管理,统一锁定版本并生成 `uv.lock`;此文档作为当前状态的快速参考。 diff --git a/docs/streaming-asr-implementation-plan.md b/docs/streaming-asr-implementation-plan.md new file mode 100644 index 0000000..8bef68c --- /dev/null +++ b/docs/streaming-asr-implementation-plan.md @@ -0,0 +1,65 @@ +# 流式 ASR 改造计划 + +## 目标 +- 前端录音时实时推送音频块至后端,避免“录完再上传”的延迟。 +- 在 `dialog-engine` 中接入真正的流式 ASR provider,让 `/chat/audio/stream` 能在音频到达时立即输出部分转写。 +- 保留现有批处理识别路径,便于回退或作为离线兜底方案。 + +## 现状差距 +- 前端 `MediaRecorder` 录音后一次性上传,缺乏实时发送和恢复机制。 +- 后端 ASR 服务以完整 `AudioBundle` 为单位,仅提供 `transcribe_bundle`。 +- SSE 层只是按顺序回放已完成的 partial,没有真正的增量推理。 + +## 工作流拆解 + +### 1.
前端实时采集与上传 +- 调整 `useApi.js`:`ondataavailable` 触发时即刻通过 WebSocket 发送音频块,维护 chunk 序号与会话状态。 +- 处理网络抖动与回压:必要时增加缓冲区、重传/丢弃策略,并在 UI 上反馈录音状态。 +- 与后端协商编码格式(例如 `audio/webm;codecs=opus` 或 PCM),并提供能力检测与降级路径。 +- 更新界面提示,包括实时录音、断线重连、错误提示等。 + +### 2. 网关/入口协议 +- 明确流式音频入站协议:WebSocket(当前基础)、HTTP chunked 或 gRPC,保持消息结构一致(开始/数据/结束)。 +- 在 gateway 层对接流式音频:转发 chunk、维护会话超时、处理取消/失败并通知前端。 +- 确保安全控制与限速,避免恶意长连接占用。 + +### 3. Dialog Engine 流式 ASR +- 扩展 `AsrProvider`:将 `stream()` 作为一等接口,允许 provider 返回增量 `AsrPartial`。 +- 引入支持流式推理的实现(如 Whisper streaming wrapper、VAD+增量解码),管理模型加载与资源。 +- 重构 `/chat/audio/stream`:在音频块到达时驱动 provider stream,并立即通过 SSE 推送 `asr-partial` / `asr-final`;同时保留 LLM 回复流。 +- 兼容现有 `transcribe_bundle`:根据配置或输入模式动态选择流式/批处理路径。 + +**当前进展**: +- 新增 `VolcengineAsrProvider`,基于豆包语音 WebSocket API 实现分片上传和增量转写,支持 `ASR_VOLC_*` 环境变量配置凭证。 +- `AsrProvider` / `AsrService` 支持真正的流式生成,`chat_audio_stream` 在 SSE 中实时发送 `asr-partial` 事件后再进入 LLM 回复阶段。 +- 为向后兼容,`transcribe_bundle` 仍可返回批量结果,且在被 monkeypatch 时作为流式路径的最终兜底。 +- `input-handler` 通过 `/chat/audio/stream` 建立 SSE,边上传音频边转发 `asr-partial`、`text-delta` 事件到 Redis,`output-handler` 与前端 WebSocket 即时展示转写与回复增量。 +- 火山引擎凭证通过 `ASR_VOLC_APP_KEY/ACCESS_KEY/RESOURCE_ID` 环境变量注入,服务端记录 `X-Tt-Logid` 并在日志/统计中追加 `asr.volcengine.latency_ms`、`error_code` 等信息,连续失败后自动回退本地 Mock ASR。 + +### 4. 基础设施与性能 +- 评估流式模型的 CPU/GPU 占用,调整部署拓扑、容器资源、自动扩缩容策略。 +- 建立可观测性:记录每个会话的 chunk 处理时延、缓冲深度、ASR 延迟等指标。 +- 优化 Redis/消息队列等依赖,确保流式事件不会阻塞现有通路。 + +### 5. 测试与质量保障 +- 单元测试:覆盖新的 provider stream 逻辑、音频块协议解析、错误分支。 +- 集成测试:模拟真实 chunk 流(例如使用 pytest-asyncio + WebSocket 客户端)验证端到端延迟与鲁棒性。 +- 前端 E2E:利用 Playwright/Cypress 验证不同浏览器、网络条件下的录音与实时转写体验。 +- 手工验收清单:双语测试、长时会话、断网恢复、设备权限处理等。 + +### 6. 发布与运维 +- 通过特性开关(如 `STREAMING_ASR_ENABLED`)控制流式路径,支持灰度发布。 +- 制定分阶段上线计划(开发 → 测试 → 预发布 → 生产),每阶段监控核心指标并准备回滚方案。 +- 更新文档:README、运维手册、故障排查指南;同步团队培训与支持流程。 + +## 时间与优先级建议 +1. **基础设施准备(前端上传 + 网关协议)**:解决数据流动路径,是开启流式的前提。 +2. **ASR provider 与服务改造**:实现实时转写核心能力。 +3. **观测与测试强化**:确保稳定性与可维护性。 +4. 
**灰度上线与优化**:在真实流量下验证表现,迭代调优。 + +## 里程碑检查 +- ✅ 可在本地环境完成端到端实时转写示例。 +- ✅ SSE 客户端在 500ms 内收到第一条 `asr-partial`。 +- ✅ 监控面板显示每条会话的 ASR 延迟与错误率。 +- ✅ 回退机制验证通过(关闭特性开关恢复到批处理模式)。 diff --git a/docs/tts-audio-debug-plan.md b/docs/tts-audio-debug-plan.md new file mode 100644 index 0000000..e7e2f48 --- /dev/null +++ b/docs/tts-audio-debug-plan.md @@ -0,0 +1,56 @@ +# TTS Audio Delivery Debug Plan + +我们目前锁定了两个可能的根因: +1. Output Handler 没有把 ingest WS 收到的音频 chunk 转发给前端 /ws/output。 +2. Dialog-engine 在 SSE 还没结束前就提前关闭 HTTP 连接(`peer closed connection`),导致 Output Handler 无法从 text 流切换到音频模式。 + +下面的计划用于缩小范围并验证修复。 + +## 1. 基线信息 +- 关注任务 ID:通过前端 console 的 `Task ID assigned` 日志获得。 +- 使用以下 docker 服务日志: + - `docker logs aivtuber-dialog-engine-dev` + - `docker logs aivtuber-input-handler-dev` + - `docker logs aivtuber-output-handler-dev` + - `docker logs aivtuber-gateway-dev` +- 前端 console 日志查看:Chrome DevTools MCP `list_console_messages`。 + +## 2. 测试用例流程 +### 2.1 正常语音输入(前端) +1. 刷新前端,点击麦克风,录制 5 秒。 +2. 观察 console:记录 Task ID、`Audio is present` 等日志。 +3. 保持页面不刷新,等待音频播放事件或报错。 + +### 2.2 脚本模拟(可重复) +1. 运行 `scripts/...` 或自定义 Python 脚本,向 `/ws/input` 推送 demo/sauc_python/input.webm。 +2. 同时连接 `/ws/output/{task_id}`,捕获所有 text/binary 帧;确认是否收到 `audio_chunk` 元数据和二进制帧。 +3. 保存输出到本地文件,便于检查是否有实际音频数据。 + +## 3. 判定哪个环节出问题 +| 证据 | 结论 | +| --- | --- | +| Output Handler 日志中,成功收到 `type=audio_chunk` 元数据并发送 bytes,前端却无音频 | 前端或 gateway 反代处理有问题(少见) | +| Output Handler 日志始终停在 `status=error` (peer closed connection),且没有 `relay_speech_chunk` 字样 | dialog-engine SSE 断开 | +| Dialog-engine 日志显示 `SPEECH_CHUNK` 大量发送,但 Output Handler 没有收到 `relay_speech_chunk` | output handler 未保持 ingest WS 连接或未匹配 session_id | + +## 4. 
修复策略 +### 4.1 如果是 Output Handler 缺转发 +- 检查 `relay_speech_chunk` 是否命中 `active_connections`。 +- 若 `active_connections` 中缺任务,可能是 output WebSocket 提前关闭 → 需延长等待逻辑。 +- 若存在,添加 debug:记录每个 chunk 的大小、发送成功与否。 +- 确认 `streaming_events` 触发,防止 WS 在 audio 完成前关闭。 + +### 4.2 如果是 Dialog-engine 提前断开 SSE +- 检查 `app.py` 中 `_schedule_tts` 是否抛异常;必要时捕获、记录。 +- 在 SSE `event_generator` 完整输出后再 `return`,确保 HTTP 流保持。 +- 若使用第三方 ASR/TTS,设置重试或延长 timeout。 + +## 5. 回归测试 +1. 前端麦克风录制场景:确保 console 出现 `Received audio chunk metadata`,Live2D 播放音频。 +2. Python 脚本回放:确认输出文件大小 > 0,内容可播放。 +3. 多次重复,确保没有 `peer closed connection`。 +4. 记录最终日志片段,归档到 docs/debug-log-*.txt 供后续参考。 + +## 6. 输出 +- 一旦定位根因,更新本文件并记录修复步骤。 +- 若调整 output handler,增加单元测试(mock Redis channel)验证音频流状态。 diff --git a/docs/volcengine-streaming-asr-notes.md b/docs/volcengine-streaming-asr-notes.md new file mode 100644 index 0000000..2640f6d --- /dev/null +++ b/docs/volcengine-streaming-asr-notes.md @@ -0,0 +1,102 @@ +# Volcengine 流式 ASR 接入笔记 + +在 `feat/real-streaming-asr` 分支上,我们为项目引入了豆包语音的大模型流式识别。这个过程踩了很多坑,下面把关键经验记录下来,方便后续排查和复用。 + +--- + +## 1. 官方 Demo vs. 我们的服务 + +| 项目 | 官方脚本(demo/sauc_python/sauc_websocket_demo.py) | 主链路(input-handler → dialog-engine) | +| --- | --- | --- | +| 音频位置 | 直接放在 demo 目录下,脚本读取本地文件 | 音频通过 WebSocket 上传到 input-handler,落在 `/tmp/aivtuber_tasks/{task_id}/input.webm` | +| 音频格式 | 脚本调用 PyAV/ffmpeg 转成 **WAV (PCM 16kHz mono s16)** 后再逐包发送 | 保存的是原始 `audio/webm;codecs=opus`,需要我们在服务端自行解码 | +| 协议实现 | 官方示例自带:全量请求 → gzip → 序列号 → 分片 | 我们初版仅发裸 PCM,未加 WAV 头 & gzip & 序列号,导致服务端只返回 `log_id` | + +> 结论:**要么像 Demo 一样发送经过封装的 WAV 分片,要么在服务端把 WebM/Opus 转成 16kHz PCM 后再接入 Volcengine。** + +--- + +## 2. 音频解码:PyAV + ffmpeg 缺一不可 + +libsndfile 在很多 Linux 发行版中并不支持 WebM/Opus,所以最开始处理 input-handler 保存的 `input.webm` 时会直接抛出 “unsupported audio”。最终我们采用了层级解码策略: + +1. **PyAV 解码(优先)** + - `services/dialog-engine/src/dialog_engine/audio/preprocessor.py` 中使用 PyAV 将内存中的 WebM 直接解成 16k 单声道 PCM。 + - 每帧解码后统一 reshape 为 `(samples, 1)`,避免 numpy 维度不匹配。 +2.
**ffmpeg 子进程兜底(若容器内存在 ffmpeg)** + - 调用 `ffmpeg -i pipe:0 -acodec pcm_f32le -ar 16000 -ac 1`,将输入转换成 float32 PCM。 +3. **再不行就按 PCM 解释** + - 将数据按 float32 或 int16 整体 reshape 成 `(N, 1)`,保证后续仍有可用的 PCM。 + +此外,为了让日志可观察,我们给音频预处理加上了 DEBUG 级别的提示: +``` +audio.preprocessor: using PyAV decoder ... +audio.preprocessor: ffmpeg decode failed ... +audio.preprocessor: falling back to raw PCM ... +``` + +> 快速验证:进入 `demo/sauc_python/` 目录,用官方脚本对 input-handler 保存的 `input.webm` 跑一遍,只要能输出“你好”之类的文本,就说明端到端数据是健康的。然后再看 dialog-engine 日志是否同样能走到 PyAV 分支。 + +--- + +## 3. 流式协议实现要点 + +和官方 Demo 对比,我们在 `VolcengineAsrProvider` 里做了以下调整(文件路径:`services/dialog-engine/src/dialog_engine/asr/providers/volcengine.py`): + +1. **完整的 config frame** + - JSON 结构必须包含 `user / audio / request`,并设置 `model_name=bigmodel`、开启 `enable_itn / enable_punc / enable_ddc / enable_intermediate_result` 等字段。 +2. **WAV 分片发送** + - 将 PCM 包装成 WAV(带 header)后分 200ms 发送;最后一包带负序号(标记结束)。 + - 每帧 gzip 压缩,header 中 `message_type_specific_flags` 按官方要求设置(0x01 正序、0x03 最后一包)。 +3. **结果解析** + - 识别结果可能嵌套在 `result / utterances / alternatives` 里,必须递归提取;否则只会得到 additions.log_id。 + +> 提示:连接成功但一直只有 `"result":{"additions":{"log_id":...}}`,通常是音频格式/协议不对;日志里若看到 `volcengine.asr.payload {"text":...}` 就标志成功。 + +--- + +## 4. 观察与排错 Checklist + +1. **前端上传音频是否存到 `/tmp/aivtuber_tasks/{task_id}/input.webm`** + - 如果文件为空,基本是前端数据链路的问题。 +2. **demo 脚本能否识别成功** + - `demo/sauc_python/sauc_websocket_demo.py` 通过 `ASR_VOLC_*` 环境变量读取凭证(脚本不会自动加载 .env,需要先导出),直接通过 PyAV + 官方协议测试。 +3. **dialog-engine 日志检查** + - `audio.preprocessor: using PyAV decoder ...` → 解码无误。 + - `volcengine.asr.payload {"text": "...", "utterances": ...}` → 已拿到识别文本。 + - 若仍是 `"unsupported audio"`,说明 PyAV/ffmpeg 均失败,需要查看容器依赖或音频本身。 +4. **input-handler 返回的错误** + - 新的处理逻辑会把 4xx 的 JSON detail 打到日志里,例如 `dialog_engine_http_error:400{"detail":"unsupported audio"}`,便于准确定位。 + +--- + +## 5. 部署与更新注意事项 + +1. **Dockerfile 更新** + - 离线环境最好在 dialog-engine 镜像里安装 ffmpeg(或包含 libav* 运行库),确保 fallback 不会失败。 +2.
**重建服务** + ``` + docker compose -f docker-compose.dev.yml build dialog-engine input-handler + docker compose -f docker-compose.dev.yml up -d dialog-engine input-handler + ``` +3. **确保环境变量正确** + - `ASR_VOLC_APP_KEY / ASR_VOLC_ACCESS_KEY / ASR_VOLC_RESOURCE_ID / ASR_VOLC_ENDPOINT`(Access Key 是否需要 `Bearer;` 前缀,取决于后台设置)。 + +--- + +## 6. 经验教训总结 + +1. **Demo 能跑,不代表服务端能直接复用** + - 官方示例处理的是“本地文件 + 完整协议 + 已封装 WAV”。我们必须在自己的链路里补齐这些步骤。 +2. **音频格式是流式识别的核心** + - 即使凭证配置正确,如果音频仍是 Opus 或变长 PCM,Volcengine 也只会返回 log_id。 + - PyAV 是现阶段最方便的纯 Python 方案,必要时增加 ffmpeg fallback。 +3. **日志要足够详细** + - DEBUG 日志中直接打印 payload 片段、解码流程,可以让排查极快定位失败点。 +4. **出现 `empty_transcript` 要优先看日志** + - 如果音频是空/静音,也会导致最终 transcript 为空,需要先确认音频内容再怀疑 ASR。 + +--- + +通过以上调整,我们已经让 demo 和主项目两条链路表现一致,并能稳定拿到 Volcengine 的识别结果。后续若有新的音频格式接入(如 AAC、MP4 等),重复上述步骤,确保解码 → 分片 → 协议三个环节无误即可。祝调试顺利! + diff --git a/front_end/src/App.vue b/front_end/src/App.vue index 2c3e7a0..5fbcda1 100644 --- a/front_end/src/App.vue +++ b/front_end/src/App.vue @@ -27,6 +27,13 @@ :color="isRecording ? 
'error' : 'primary'" @click="toggleMicrophone" > + @@ -64,6 +71,37 @@ + + + + 发送文字消息 + + + + {{ textInputError }} + + + + 取消 + 发送 + + + @@ -75,6 +113,7 @@ import Live2DControls from './components/Live2DControls.vue'; import SubtitleBar from './components/SubtitleBar.vue'; import { useApi } from './composables/useApi'; import { useStreamingChat } from './composables/useStreamingChat'; +import { useSubtitleFeed } from './composables/useSubtitleFeed'; const live2dViewer = ref(null); const modelLoaded = ref(false); @@ -82,28 +121,28 @@ const currentModelPath = ref('/src/live2d/models/Haru/Haru.model3.json'); const showConfigPanel = ref(false); const motions = ref([]); const expressions = ref([]); -const subtitleText = ref(''); +const { subtitleText, appendDelta, replaceText, clearSubtitle, beginStream } = useSubtitleFeed(); const viewerWidth = ref(400); const viewerHeight = ref(600); +const showTextInputDialog = ref(false); +const textInputContent = ref(''); +const textInputError = ref(''); +const isSendingText = ref(false); const { receivedAudioUrl, isRecording, startRecording, stopRecording, recordingError } = useApi(); const { startStreaming, cancelStreaming, onDelta, onDone, onError } = useStreamingChat(); const detachDelta = onDelta((delta) => { - if (!delta) return; - subtitleText.value += delta; + appendDelta(delta); }); -const clearSubtitle = () => { - subtitleText.value = ''; -}; - const detachDone = onDone(() => { - // 保留钩子以便后续扩展,例如记录完成信息 + replaceText(subtitleText.value); }); const detachError = onError((error) => { console.error('字幕流式输出发生错误:', error); + clearSubtitle(); }); const handleModelLoaded = (model) => { @@ -172,9 +211,42 @@ watch(recordingError, (error) => { } }); +const openTextInputDialog = () => { + textInputContent.value = ''; + textInputError.value = ''; + showTextInputDialog.value = true; +}; + +const closeTextInputDialog = () => { + if (isSendingText.value) return; + showTextInputDialog.value = false; + textInputContent.value = ''; + 
textInputError.value = ''; +}; + +const submitTextInput = async () => { + const content = textInputContent.value.trim(); + if (!content) { + textInputError.value = '请输入要发送的内容。'; + return; + } + textInputError.value = ''; + isSendingText.value = true; + try { + await sendUserText(content, { sessionId: 'text-dialog' }); + showTextInputDialog.value = false; + textInputContent.value = ''; + } catch (error) { + console.error('发送文字消息失败:', error); + textInputError.value = '发送失败,请稍后重试。'; + } finally { + isSendingText.value = false; + } +}; + const sendUserText = async (content, options = {}) => { if (!content) return; - clearSubtitle(); + beginStream(); await startStreaming(options.sessionId || 'default', content, options.payload); }; @@ -243,6 +315,13 @@ defineExpose({ z-index: 5; } +.text-toggle { + position: absolute; + bottom: 24px; + right: 24px; + z-index: 6; +} + .config-dialog { max-height: 80vh; display: flex; diff --git a/front_end/src/components/ChatInterface.vue b/front_end/src/components/ChatInterface.vue index f666fc5..62f9437 100644 --- a/front_end/src/components/ChatInterface.vue +++ b/front_end/src/components/ChatInterface.vue @@ -84,6 +84,38 @@ + + + mdi-pencil + + + + + 发送文字消息 + + + + + 取消 + 发送 + + + @@ -98,6 +130,10 @@ const messages = ref([ // { sender: 'ai', text: '你好!有什么可以帮你的吗?' 
} ]); const chatListRef = ref(null); // Ref for the message list container +const streamingUserMessageId = ref(null); +const streamingAiMessageId = ref(null); +const isTextInputDialogOpen = ref(false); +const dialogInputText = ref(''); // --- API Composable --- const { @@ -107,6 +143,8 @@ const { isProcessing, processingError, receivedText, + streamingTranscript, + streamingReply, receivedAudioUrl, // Import this to potentially display/play received audio later connectInput, sendTextInput, @@ -176,13 +214,82 @@ const toggleRecording = () => { } }; +const openTextInputDialog = () => { + dialogInputText.value = ''; + isTextInputDialogOpen.value = true; +}; + +const closeTextInputDialog = () => { + isTextInputDialogOpen.value = false; + dialogInputText.value = ''; +}; + +const submitDialogText = async () => { + const text = dialogInputText.value.trim(); + if (!text) { + return; + } + newMessage.value = text; + try { + await sendMessage(); + closeTextInputDialog(); + } catch (error) { + console.error('Dialog text send failed:', error); + } +}; + // --- Watchers --- +watch(streamingTranscript, (newText) => { + if (!newText) { + streamingUserMessageId.value = null; + return; + } + if (!streamingUserMessageId.value) { + const id = Date.now() + Math.random(); + streamingUserMessageId.value = id; + messages.value.push({ sender: 'user', text: newText, id }); + } else { + const msg = messages.value.find((m) => m.id === streamingUserMessageId.value); + if (msg) { + msg.text = newText; + } + } + scrollToBottom(); +}); + +watch(streamingReply, (newText) => { + if (!newText) { + streamingAiMessageId.value = null; + return; + } + if (!streamingAiMessageId.value) { + const id = Date.now() + Math.random(); + streamingAiMessageId.value = id; + messages.value.push({ sender: 'ai', text: newText, id }); + } else { + const msg = messages.value.find((m) => m.id === streamingAiMessageId.value); + if (msg) { + msg.text = newText; + } + } + scrollToBottom(); +}); + // Watch for AI text 
response watch(receivedText, (newText) => { if (newText) { - // Add the actual AI response directly (no loading message to remove) - messages.value.push({ sender: 'ai', text: newText, id: Date.now() + Math.random() }); + if (streamingAiMessageId.value) { + const msg = messages.value.find((m) => m.id === streamingAiMessageId.value); + if (msg) { + msg.text = newText; + } else { + messages.value.push({ sender: 'ai', text: newText, id: Date.now() + Math.random() }); + } + streamingAiMessageId.value = null; + } else { + messages.value.push({ sender: 'ai', text: newText, id: Date.now() + Math.random() }); + } scrollToBottom(); // Reset receivedText in composable? Or assume it's only set once per response. receivedText.value = ''; // Clear it after processing @@ -272,6 +379,13 @@ onMounted(() => { font-size: 0.9em; padding: 6px 12px; } + +.text-input-fab { + position: fixed; + right: 24px; + bottom: 24px; + z-index: 10; +} .error-text { display: flex; align-items: center; @@ -360,4 +474,4 @@ onMounted(() => { } } - \ No newline at end of file + diff --git a/front_end/src/composables/useApi.js b/front_end/src/composables/useApi.js index 45f3e9b..e42ec19 100644 --- a/front_end/src/composables/useApi.js +++ b/front_end/src/composables/useApi.js @@ -1,4 +1,5 @@ import { ref, shallowRef } from 'vue'; +import { useSubtitleFeed } from './useSubtitleFeed'; // --- Configuration --- // 在实际应用中,这些应该来自配置文件或环境变量 @@ -25,22 +26,81 @@ const isProcessing = ref(false); // 用于表示从发送完成到收到结果 const uploadCompleteConfirmed = ref(false); // 确认 upload_complete 已被后端处理 const processingError = ref(null); // 存储处理过程中或连接中的错误信息 const receivedText = ref(''); // 存储接收到的 AI 文本结果 +const streamingTranscript = ref(''); // 实时 ASR 文本 +const streamingReply = ref(''); // 实时回复片段 // Audio related state const audioChunks = shallowRef([]); // Store received audio binary chunks -const expectedAudioChunks = ref(0); +const expectedAudioChunks = ref(0); // -1 indicates streaming/unknown length const receivedAudioChunkCount 
= ref(0); const lastReceivedAudioChunkId = ref(-1); // Track the ID for the next expected binary chunk const receivedAudioUrl = ref(null); // URL for the final reassembled audio // Recording specific state +const isStreamingAudio = ref(false); +const streamingCodec = ref(null); +const streamingCongestionWarning = ref(false); const isRecording = ref(false); const recorder = shallowRef(null); const recordedAudioChunks = shallowRef([]); // Store Blob chunks from MediaRecorder const recordingError = ref(null); // Store errors related to recording process +const selectedChunkDurationMs = ref(0); + +const STREAMING_CHUNK_INTERVAL_MS = 250; +const WS_BUFFER_THRESHOLD_BYTES = 512 * 1024; +const MAX_PENDING_CHUNKS = 20; +const MAX_PENDING_BYTES = 5 * 1024 * 1024; +const STREAMING_UNKNOWN_CHUNKS = -1; +const AUDIO_MIME_CANDIDATES = [ + 'audio/webm;codecs=opus', + 'audio/webm', + 'audio/ogg;codecs=opus', + 'audio/wav', +]; + +let pendingAudioChunks = []; +let pendingAudioBytes = 0; +let nextAudioChunkId = 0; +let hasSentAudioStart = false; +let chunkFlushTimer = null; // 使用 shallowRef 以避免深度响应性带来的性能问题,因为 WebSocket 实例不应是深度响应式的 const inputWs = shallowRef(null); const outputWs = shallowRef(null); +const subtitleFeed = useSubtitleFeed(); + +const resetStreamingInputState = () => { + isStreamingAudio.value = false; + streamingCodec.value = null; + streamingCongestionWarning.value = false; + selectedChunkDurationMs.value = 0; + pendingAudioChunks = []; + pendingAudioBytes = 0; + nextAudioChunkId = 0; + hasSentAudioStart = false; + if (chunkFlushTimer !== null) { + clearTimeout(chunkFlushTimer); + chunkFlushTimer = null; + } +}; + +const pickSupportedAudioMimeType = () => { + if (typeof window === 'undefined' || typeof window.MediaRecorder === 'undefined') { + return null; + } + if (typeof window.MediaRecorder.isTypeSupported !== 'function') { + return null; + } + for (const candidate of AUDIO_MIME_CANDIDATES) { + try { + if (window.MediaRecorder.isTypeSupported(candidate)) { + 
return candidate; + } + } catch (error) { + console.warn('MediaRecorder.isTypeSupported error for', candidate, error); + } + } + return null; +}; // --- WebSocket Logic --- @@ -53,6 +113,8 @@ const resetState = () => { uploadCompleteConfirmed.value = false; processingError.value = null; receivedText.value = ''; + streamingTranscript.value = ''; + streamingReply.value = ''; audioChunks.value = []; expectedAudioChunks.value = 0; receivedAudioChunkCount.value = 0; @@ -69,13 +131,148 @@ const resetState = () => { recorder.value = null; recordedAudioChunks.value = []; recordingError.value = null; + resetStreamingInputState(); + subtitleFeed.clearSubtitle(); }; export function useApi() { + const scheduleAudioChunkFlush = () => { + if (chunkFlushTimer !== null) { + return; + } + chunkFlushTimer = window.setTimeout(() => { + chunkFlushTimer = null; + flushAudioChunkQueue(); + }, 10); + }; + + const flushAudioChunkQueue = (options = {}) => { + if (!inputWs.value || inputWs.value.readyState !== WebSocket.OPEN) { + return; + } + if (!hasSentAudioStart) { + return; + } + + const force = options.force === true; + while (pendingAudioChunks.length > 0) { + if (!force && inputWs.value.bufferedAmount > WS_BUFFER_THRESHOLD_BYTES) { + streamingCongestionWarning.value = true; + scheduleAudioChunkFlush(); + return; + } + + const chunk = pendingAudioChunks.shift(); + pendingAudioBytes -= chunk.blob.size; + + const metadata = { + type: 'audio', + action: 'data_chunk', + chunk_id: chunk.id, + size: chunk.blob.size, + ts: chunk.timestamp, + }; + + try { + inputWs.value.send(JSON.stringify(metadata)); + inputWs.value.send(chunk.blob); + } catch (error) { + console.error('Failed to send streaming audio chunk:', error); + pendingAudioChunks.unshift(chunk); + pendingAudioBytes += chunk.blob.size; + processingError.value = '发送实时音频数据时出错。'; + streamingCongestionWarning.value = true; + return; + } + } + streamingCongestionWarning.value = false; + }; + +const beginAudioStreamingSession = ({ codec, 
sampleRate, chunkDurationMs }) => { + if (!inputWs.value || inputWs.value.readyState !== WebSocket.OPEN) { + throw new Error('输入 WebSocket 未连接,无法开始音频流。'); + } + if (!taskId.value) { + throw new Error('任务 ID 尚未分配,无法开始音频流。'); + } + if (hasSentAudioStart) { + return; + } + + subtitleFeed.beginStream(); + const startPayload = { + type: 'audio', + action: 'start', + codec: codec || null, + sample_rate: sampleRate || null, + chunk_duration_ms: chunkDurationMs || null, + }; + + inputWs.value.send(JSON.stringify(startPayload)); + console.log('Sent audio streaming start payload:', startPayload); + + streamingCodec.value = codec || null; + selectedChunkDurationMs.value = chunkDurationMs || 0; + hasSentAudioStart = true; + isStreamingAudio.value = true; + isProcessing.value = true; + uploadCompleteConfirmed.value = false; + }; + + const enqueueAudioChunk = (blob) => { + if (!blob || blob.size === 0) { + return; + } + if (!hasSentAudioStart) { + console.warn('音频流尚未初始化,忽略音频块。'); + return; + } + const chunk = { + id: nextAudioChunkId, + blob, + timestamp: Date.now(), + }; + nextAudioChunkId += 1; + pendingAudioChunks.push(chunk); + pendingAudioBytes += blob.size; + + if (pendingAudioChunks.length > MAX_PENDING_CHUNKS || pendingAudioBytes > MAX_PENDING_BYTES) { + streamingCongestionWarning.value = true; + processingError.value = '实时音频发送滞后,请检查网络连接。'; + console.warn('Audio streaming backpressure detected, stopping recorder to prevent overflow.'); + if (recorder.value && recorder.value.state !== 'inactive') { + recorder.value.stop(); + } + return; + } + + flushAudioChunkQueue(); + }; + + const finalizeAudioStreaming = (reason = 'completed') => { + if (!hasSentAudioStart) { + return; + } + flushAudioChunkQueue({ force: true }); + + if (inputWs.value && inputWs.value.readyState === WebSocket.OPEN) { + const stopPayload = { + type: 'audio', + action: 'stop', + total_chunks: nextAudioChunkId, + reason, + }; + inputWs.value.send(JSON.stringify(stopPayload)); + console.log('Sent audio 
streaming stop payload:', stopPayload); + } + + resetStreamingInputState(); + }; + // --- Input WebSocket Handling --- - const connectInput = () => { +const connectInput = () => { // Return a Promise that resolves with the taskId or rejects on error return new Promise((resolve, reject) => { console.log('Attempting to connect to input WebSocket:', wsInputUrl); @@ -98,6 +295,7 @@ export function useApi() { } resetState(); // Reset all states before new connection attempt + subtitleFeed.clearSubtitle(); inputWs.value = new WebSocket(wsInputUrl); @@ -106,6 +304,9 @@ export function useApi() { isConnectedInput.value = true; processingError.value = null; // Don't resolve yet, wait for task_id + if (hasSentAudioStart && pendingAudioChunks.length > 0) { + scheduleAudioChunkFlush(); + } }; inputWs.value.onmessage = (event) => { @@ -148,12 +349,14 @@ export function useApi() { processingError.value = errorMsg; isConnectedInput.value = false; isProcessing.value = false; + resetStreamingInputState(); reject(new Error(errorMsg)); // Reject the Promise on error }; inputWs.value.onclose = (event) => { console.log('Input WebSocket disconnected:', event.reason); isConnectedInput.value = false; + resetStreamingInputState(); inputWs.value = null; // Reject if closed before task ID was assigned? if (!taskId.value && event.code !== 1000) { @@ -180,6 +383,7 @@ export function useApi() { } console.log(`Sending text input (task ${taskId.value}):`, text); + subtitleFeed.beginStream(); try { // 1. 
Send metadata JSON const metadata = { @@ -212,38 +416,49 @@ export function useApi() { inputWs.value.close(); inputWs.value = null; // Release the reference } + resetStreamingInputState(); }; // --- Audio Reassembly --- const reassembleAndHandleAudio = () => { - if (receivedAudioChunkCount.value !== expectedAudioChunks.value || audioChunks.value.length !== expectedAudioChunks.value) { - console.error(`Audio reassembly error: Expected ${expectedAudioChunks.value} chunks, received ${receivedAudioChunkCount.value}. Array length: ${audioChunks.value.length}`); - processingError.value = "音频数据接收不完整。"; - resetAudioState(); + if (expectedAudioChunks.value === 0 && audioChunks.value.length === 0) { + console.log("No audio chunks were expected or received."); return; } - if (expectedAudioChunks.value === 0) { - console.log("No audio chunks were expected or received."); - return; // Nothing to reassemble + + const isStreaming = expectedAudioChunks.value === STREAMING_UNKNOWN_CHUNKS; + if (!isStreaming) { + if ( + expectedAudioChunks.value <= 0 || + receivedAudioChunkCount.value !== expectedAudioChunks.value || + audioChunks.value.length !== expectedAudioChunks.value + ) { + console.error( + `Audio reassembly error: Expected ${expectedAudioChunks.value} chunks, received ${receivedAudioChunkCount.value}. Array length: ${audioChunks.value.length}`, + ); + processingError.value = "音频数据接收不完整。"; + resetAudioState(); + return; + } } try { - console.log("Reassembling audio from", audioChunks.value.length, "chunks..."); - // Ensure all chunks are valid before creating blob - const validChunks = audioChunks.value.filter(chunk => chunk instanceof Blob || chunk instanceof ArrayBuffer); - if (validChunks.length !== expectedAudioChunks.value) { - throw new Error("Some received audio chunks are invalid."); + const chunks = Array.isArray(audioChunks.value) ? 
audioChunks.value : []; + console.log("Reassembling audio from", chunks.length, "chunks..."); + const normalized = chunks + .filter((chunk) => chunk instanceof Blob || chunk instanceof ArrayBuffer) + .map((chunk) => (chunk instanceof Blob ? chunk : new Blob([chunk]))); + if (!normalized.length) { + throw new Error("No valid audio chunks to assemble."); } - // 确认 MP3 为统一输出格式 - const completeAudioBlob = new Blob(validChunks, { type: 'audio/mpeg' }); - console.log("Audio reassembled successfully. Blob size:", completeAudioBlob.size); + const mimeType = isStreaming ? 'audio/wav' : 'audio/mpeg'; + const completeAudioBlob = new Blob(normalized, { type: mimeType }); + console.log("Audio reassembled successfully. Blob size:", completeAudioBlob.size, 'type:', mimeType); - // Clean up previous URL if exists if (receivedAudioUrl.value) { URL.revokeObjectURL(receivedAudioUrl.value); } - // Create Object URL for playback receivedAudioUrl.value = URL.createObjectURL(completeAudioBlob); console.log("Created audio object URL:", receivedAudioUrl.value); @@ -251,7 +466,6 @@ export function useApi() { console.error("Error reassembling audio:", error); processingError.value = "处理接收到的音频时出错。"; } finally { - // Reset audio chunk state after processing resetAudioState(); } }; @@ -302,24 +516,32 @@ export function useApi() { // --- Handle Binary Data First (Audio Chunk) --- if (event.data instanceof Blob || event.data instanceof ArrayBuffer) { console.log(`Received binary audio data chunk (expecting ID ${lastReceivedAudioChunkId.value}).`); - if (lastReceivedAudioChunkId.value !== -1 && lastReceivedAudioChunkId.value < expectedAudioChunks.value) { - // Store the received chunk - // Make sure the array is initialized - if (audioChunks.value.length !== expectedAudioChunks.value && expectedAudioChunks.value > 0) { - console.warn(`Audio chunk array not initialized correctly. Expected ${expectedAudioChunks.value}, got ${audioChunks.value.length}. 
Reinitializing.`); - audioChunks.value = new Array(expectedAudioChunks.value); - } - if (lastReceivedAudioChunkId.value < audioChunks.value.length) { - audioChunks.value[lastReceivedAudioChunkId.value] = event.data; // Use Blob directly - receivedAudioChunkCount.value++; - console.log(`Stored audio chunk ${lastReceivedAudioChunkId.value}. Received ${receivedAudioChunkCount.value}/${expectedAudioChunks.value}`); + const expectingStreaming = expectedAudioChunks.value === STREAMING_UNKNOWN_CHUNKS; + if (lastReceivedAudioChunkId.value !== -1) { + if (expectingStreaming) { + if (!Array.isArray(audioChunks.value)) { + audioChunks.value = []; + } + audioChunks.value.push(event.data); + receivedAudioChunkCount.value++; + console.log(`Stored streaming audio chunk ${receivedAudioChunkCount.value}`); + } else if (expectedAudioChunks.value > 0 && lastReceivedAudioChunkId.value < expectedAudioChunks.value) { + if (audioChunks.value.length !== expectedAudioChunks.value) { + console.warn(`Audio chunk array not initialized correctly. Expected ${expectedAudioChunks.value}, got ${audioChunks.value.length}. Reinitializing.`); + audioChunks.value = new Array(expectedAudioChunks.value); + } + audioChunks.value[lastReceivedAudioChunkId.value] = event.data; + receivedAudioChunkCount.value++; + console.log(`Stored audio chunk ${lastReceivedAudioChunkId.value}. Received ${receivedAudioChunkCount.value}/${expectedAudioChunks.value}`); } else { - console.error(`Received audio chunk with invalid index: ${lastReceivedAudioChunkId.value}`); - // Handle error? Maybe disconnect? 
+ console.warn('Received binary audio chunk but allocation is not ready.', { + idx: lastReceivedAudioChunkId.value, + expected: expectedAudioChunks.value, + }); } - lastReceivedAudioChunkId.value = -1; // Reset, wait for next metadata chunk + lastReceivedAudioChunkId.value = -1; } else { - console.warn("Received unexpected binary data when not expecting audio chunk ID:", lastReceivedAudioChunkId.value); + console.warn('Received unexpected binary data when not expecting audio chunk ID:', lastReceivedAudioChunkId.value); } return; // Processed binary data, exit handler } @@ -328,14 +550,17 @@ export function useApi() { try { const message = JSON.parse(event.data); - if (message.status === 'success' && message.task_id === currentTaskId) { - console.log('Received successful text result:', message.content); - receivedText.value = message.content || ''; - // If audio is NOT present, processing is fully complete. - if (!message.audio_present) { - isProcessing.value = false; // Processing complete - disconnectOutput(); - } else { + if (message.status === 'success' && message.task_id === currentTaskId) { + console.log('Received successful text result:', message.content); + receivedText.value = message.content || ''; + streamingReply.value = ''; + streamingTranscript.value = ''; + subtitleFeed.replaceText(receivedText.value); + // If audio is NOT present, processing is fully complete. 
+ if (!message.audio_present) { + isProcessing.value = false; // Processing complete + disconnectOutput(); + } else { console.log("Audio is present, preparing to receive chunks..."); // Reset audio state for receiving, wait for first audio_chunk metadata resetAudioState(); @@ -343,10 +568,18 @@ export function useApi() { } } else if (message.type === 'audio_chunk' && message.task_id === currentTaskId) { console.log(`Received audio chunk metadata: ID ${message.chunk_id}/${message.total_chunks}`); - if (message.chunk_id === 0) { // First chunk metadata - expectedAudioChunks.value = message.total_chunks || 0; - audioChunks.value = new Array(expectedAudioChunks.value); // Initialize array - receivedAudioChunkCount.value = 0; // Reset count + if (message.chunk_id === 0 || expectedAudioChunks.value === 0) { + const total = typeof message.total_chunks === 'number' ? message.total_chunks : 0; + if (total > 0) { + expectedAudioChunks.value = total; + audioChunks.value = new Array(total); + } else { + expectedAudioChunks.value = STREAMING_UNKNOWN_CHUNKS; + audioChunks.value = []; + } + receivedAudioChunkCount.value = 0; + } else if (expectedAudioChunks.value === STREAMING_UNKNOWN_CHUNKS && !Array.isArray(audioChunks.value)) { + audioChunks.value = []; } // Verify chunk_id continuity if needed lastReceivedAudioChunkId.value = message.chunk_id; // Set expectation for next binary msg @@ -355,18 +588,46 @@ export function useApi() { isProcessing.value = false; // All processing is now complete reassembleAndHandleAudio(); // Process the collected chunks disconnectOutput(); // Disconnect after processing audio - } else if (message.status === 'error') { - console.error('Output WebSocket error message:', message.error); - processingError.value = message.error || '未知处理错误。'; - isProcessing.value = false; - disconnectOutput(); - } else { + } else if (message.type === 'control' && message.task_id === currentTaskId) { + const action = typeof message.action === 'string' ? 
message.action.toUpperCase() : ''; + if (action === 'END') { + console.log('Received control END signal for audio stream.'); + isProcessing.value = false; + reassembleAndHandleAudio(); + disconnectOutput(); + } + } else if (message.status === 'streaming' && message.task_id === currentTaskId) { + const eventType = message.event; + if (eventType === 'asr-partial' || eventType === 'asr-final') { + if (typeof message.text === 'string') { + streamingTranscript.value = message.text; + } + if (eventType === 'asr-final') { + streamingTranscript.value = message.text || streamingTranscript.value; + } + } else if (eventType === 'text-delta') { + const deltaContent = typeof message.content === 'string' ? message.content : ''; + if (deltaContent) { + streamingReply.value += deltaContent; + subtitleFeed.appendDelta(deltaContent); + } + } + } else if (message.status === 'error') { + console.error('Output WebSocket error message:', message.error); + processingError.value = message.error || '未知处理错误。'; + isProcessing.value = false; + streamingReply.value = ''; + streamingTranscript.value = ''; + subtitleFeed.clearSubtitle(); + disconnectOutput(); + } else { console.warn('Received unknown message structure on output WebSocket:', message); } } catch (e) { console.error('Error parsing output WebSocket message:', e); processingError.value = '无法解析来自服务器的响应。'; isProcessing.value = false; + subtitleFeed.clearSubtitle(); disconnectOutput(); // Disconnect on parsing error } }; @@ -429,58 +690,75 @@ export function useApi() { console.log("Requesting microphone access..."); recordingError.value = null; // Clear previous errors - recordedAudioChunks.value = []; // Reset chunks + recordedAudioChunks.value = []; // 保留以兼容旧逻辑,但不再用于发送 + streamingTranscript.value = ''; + streamingReply.value = ''; + let stream; try { - const stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + stream = await navigator.mediaDevices.getUserMedia({ audio: true }); console.log("Microphone access 
granted."); + } catch (err) { + console.error("Error accessing microphone:", err); + if (err.name === 'NotAllowedError' || err.name === 'PermissionDeniedError') { + recordingError.value = "未获得麦克风权限。请在浏览器设置中允许访问。"; + } else { + recordingError.value = `无法访问麦克风: ${err.name}`; + } + isRecording.value = false; + return; + } - // TODO: Consider specific MIME types supported by the backend (e.g., 'audio/wav', 'audio/webm') - // Browsers might have different defaults or support levels. 'audio/webm;codecs=opus' is common. - recorder.value = new MediaRecorder(stream, { mimeType: 'audio/webm;codecs=opus' }); + try { + const preferredMime = pickSupportedAudioMimeType(); + const recorderOptions = preferredMime ? { mimeType: preferredMime } : undefined; + recorder.value = new MediaRecorder(stream, recorderOptions); + const actualMime = recorder.value.mimeType || preferredMime || null; + streamingCodec.value = actualMime; + + const audioTrack = stream.getAudioTracks()[0]; + const trackSettings = audioTrack ? audioTrack.getSettings() : {}; + const sampleRate = trackSettings.sampleRate; + const chunkDurationMs = STREAMING_CHUNK_INTERVAL_MS; + + beginAudioStreamingSession({ + codec: actualMime, + sampleRate, + chunkDurationMs, + }); recorder.value.ondataavailable = (event) => { - if (event.data.size > 0) { - recordedAudioChunks.value.push(event.data); - console.log("Recorded chunk size:", event.data.size); + if (event.data && event.data.size > 0) { + enqueueAudioChunk(event.data); + console.log("Recorded streaming chunk size:", event.data.size); } }; recorder.value.onstop = () => { console.log("[useApi] MediaRecorder onstop event triggered."); - console.log("Recording stopped. 
Total chunks:", recordedAudioChunks.value.length); isRecording.value = false; - // Stop the tracks to release the microphone stream.getTracks().forEach(track => track.stop()); - // TODO: Trigger sending the collected audio chunks - if (recordedAudioChunks.value.length > 0) { - sendAudioInput(); // Call the function to send data - } else { - console.warn("No audio data recorded."); - } - console.log("[useApi] MediaRecorder onstop event handler finished."); + finalizeAudioStreaming('stopped'); + console.log("[useApi] MediaRecorder onstop handler finished."); }; recorder.value.onerror = (event) => { console.error("MediaRecorder error:", event.error); - recordingError.value = `录音出错: ${event.error.name}`; - isRecording.value = false; - // Stop the tracks on error as well + recordingError.value = `录音出错: ${event.error?.name || '未知错误'}`; stream.getTracks().forEach(track => track.stop()); + isRecording.value = false; + finalizeAudioStreaming('error'); }; - recorder.value.start(); // Start recording + recorder.value.start(chunkDurationMs); isRecording.value = true; - console.log("Recording started."); - + console.log(`Recording started. 
mimeType=${actualMime}, timeslice=${chunkDurationMs}ms`); } catch (err) { - console.error("Error accessing microphone:", err); - if (err.name === 'NotAllowedError' || err.name === 'PermissionDeniedError') { - recordingError.value = "未获得麦克风权限。请在浏览器设置中允许访问。"; - } else { - recordingError.value = `无法访问麦克风: ${err.name}`; - } - isRecording.value = false; // Ensure recording state is false on error + console.error("Failed to start MediaRecorder:", err); + stream.getTracks().forEach(track => track.stop()); + recordingError.value = `录音初始化失败: ${err.name || err.message}`; + isRecording.value = false; + finalizeAudioStreaming('error'); } }; @@ -496,100 +774,6 @@ export function useApi() { // --- Sending Audio Input --- // Placeholder for the actual audio sending logic - const sendAudioInput = async () => { - if (!recordedAudioChunks.value || recordedAudioChunks.value.length === 0) { - console.error("No recorded audio data to send."); - return; - } - console.log("Preparing to send recorded audio..."); - - // 1. Ensure input WebSocket is connected and has a task ID - if (!inputWs.value || inputWs.value.readyState !== WebSocket.OPEN || !taskId.value) { - console.log("Input WebSocket not ready, attempting to connect..."); - try { - // Use the existing connectInput logic which returns a Promise resolving with taskId - await connectInput(); // This should set up inputWs and taskId - if (!inputWs.value || inputWs.value.readyState !== WebSocket.OPEN || !taskId.value) { - throw new Error("Failed to establish input connection or get task ID after attempt."); - } - console.log("Input connection established for audio sending. Task ID:", taskId.value); - } catch (error) { - console.error("Failed to connect input WebSocket for audio sending:", error); - processingError.value = "发送音频前无法连接到服务器。"; - isProcessing.value = false; // Ensure processing is stopped - // Clear recorded chunks if connection fails? Or allow retry? - // recordedAudioChunks.value = []; - return; - } - } - - // 2. 
Combine recorded chunks into a single Blob - // The backend expects chunks, but MediaRecorder gives chunks based on time slices or buffer fullness. - // We might need to re-chunk the combined Blob if the backend has strict size limits, - // or send the MediaRecorder chunks directly if the backend can handle variable sizes. - // For simplicity now, let's assume we send the chunks as MediaRecorder produced them. - - console.log(`Sending ${recordedAudioChunks.value.length} audio chunks for task ${taskId.value}...`); - processingError.value = null; // Clear previous errors - isProcessing.value = true; // Start processing indicator - uploadCompleteConfirmed.value = false; // Reset confirmation flag - - try { - const chunksToSend = recordedAudioChunks.value; // Use the chunks directly - - // 3. Send chunks sequentially (metadata + binary) - for (let i = 0; i < chunksToSend.length; i++) { - const chunk = chunksToSend[i]; - const metadata = { - type: "audio", // Or recorder.value.mimeType if backend needs it? - chunk_id: i, - action: "data_chunk" - }; - - if (inputWs.value.readyState !== WebSocket.OPEN) { - throw new Error("Input WebSocket closed unexpectedly during audio chunk sending."); - } - inputWs.value.send(JSON.stringify(metadata)); - console.log(`Sent audio metadata chunk ${i}:`, metadata); - - // Wait a tiny moment for the JSON to likely be processed before sending binary? Might not be needed. - // await new Promise(resolve => setTimeout(resolve, 5)); - - if (inputWs.value.readyState !== WebSocket.OPEN) { - throw new Error("Input WebSocket closed unexpectedly before sending binary audio chunk."); - } - inputWs.value.send(chunk); // Send the Blob directly - console.log(`Sent audio binary chunk ${i}, size: ${chunk.size}`); - - // Optional: Wait for 'File chunk received' confirmation? Could slow things down. - } - - // 4. 
Send upload complete signal - if (inputWs.value.readyState !== WebSocket.OPEN) { - throw new Error("Input WebSocket closed unexpectedly before sending upload_complete."); - } - const uploadCompleteSignal = { action: "upload_complete" }; - inputWs.value.send(JSON.stringify(uploadCompleteSignal)); - console.log('Sent upload_complete signal for audio.'); - - // 5. Cleanup recorded data after successful initiation of sending - recordedAudioChunks.value = []; - - // Note: isProcessing will be set to false when the output WS receives the final result or error. - // The input WS will disconnect itself upon receiving 'upload_processed'. - - } catch (error) { - console.error("Error sending audio chunks:", error); - processingError.value = `发送音频时出错: ${error.message}`; - isProcessing.value = false; // Stop processing indicator on error - // Consider disconnecting inputWS here on error? It might auto-close anyway. - // disconnectInput(); - // disconnectInput(); - // Clear potentially partially sent data? 
- recordedAudioChunks.value = []; - } - }; - // --- HTTP TTS 拉取已移除:仅保留双 WS 模式 --- // Return reactive refs and methods @@ -601,10 +785,15 @@ export function useApi() { uploadCompleteConfirmed, processingError, receivedText, + streamingTranscript, + streamingReply, receivedAudioUrl, // Export the audio URL // Export recording state and methods isRecording, recordingError, + isStreamingAudio, + streamingCodec, + streamingCongestionWarning, startRecording, stopRecording, // Export methods diff --git a/front_end/src/composables/useSubtitleFeed.js b/front_end/src/composables/useSubtitleFeed.js new file mode 100644 index 0000000..6aeabbb --- /dev/null +++ b/front_end/src/composables/useSubtitleFeed.js @@ -0,0 +1,38 @@ +import { ref } from 'vue'; + +const subtitleText = ref(''); +const isStreaming = ref(false); + +const beginStream = () => { + subtitleText.value = ''; + isStreaming.value = true; +}; + +const appendDelta = (chunk) => { + if (!chunk) { + return; + } + subtitleText.value += chunk; + isStreaming.value = true; +}; + +const replaceText = (text) => { + subtitleText.value = text || ''; + isStreaming.value = false; +}; + +const clearSubtitle = () => { + subtitleText.value = ''; + isStreaming.value = false; +}; + +export function useSubtitleFeed() { + return { + subtitleText, + isStreaming, + beginStream, + appendDelta, + replaceText, + clearSubtitle, + }; +} diff --git a/requirements-dev.txt b/requirements-dev.txt index 76780ec..03291c9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,7 +11,7 @@ websockets==12.0 python-multipart==0.0.6 # ===== AI和HTTP客户端 ===== -openai==1.7.2 +openai==1.68.0 httpx==0.25.0 aiohttp==3.9.3 diff --git a/services/dialog-engine/requirements.txt b/services/dialog-engine/requirements.txt index 4b82bdf..59cf4eb 100644 --- a/services/dialog-engine/requirements.txt +++ b/services/dialog-engine/requirements.txt @@ -2,13 +2,13 @@ edge-tts==7.2.1 fastapi==0.114.2 faster-whisper==1.0.3 httpx==0.25.0 -numpy==1.26.4 -openai==1.7.2 
+av==12.3.0 +numpy>=2.0.2,<3 +openai==1.68.0 pytest==8.4.2 pytest-asyncio==0.23.8 pytest-mock==3.14.0 resampy==0.4.3 -soundfile==0.12.1 redis==5.0.1 uvicorn[standard]==0.30.6 websockets==12.0 diff --git a/services/dialog-engine/src/dialog_engine/app.py b/services/dialog-engine/src/dialog_engine/app.py index c466171..d07ffdb 100644 --- a/services/dialog-engine/src/dialog_engine/app.py +++ b/services/dialog-engine/src/dialog_engine/app.py @@ -2,10 +2,11 @@ import base64 import binascii import json -import logging import os import time -from typing import AsyncGenerator, Dict, Any, List +from typing import AsyncGenerator, Dict, Any, List, Optional + +import logging import redis.asyncio as redis @@ -14,7 +15,7 @@ from .chat_service import ChatService from .audio import AudioBundle, AudioIngestor, AudioPreprocessor, IngestLimits -from .asr import AsrOptions, AsrService +from .asr import AsrOptions, AsrPartial, AsrResult, AsrService from .tts_streamer import stream_text as tts_stream_text from .ltm_outbox import add_event as outbox_add_event, start_flush_task as outbox_start_flush from .internal_state_store import InternalStateStore @@ -22,7 +23,16 @@ app = FastAPI() + +_log_level_name = os.getenv("LOG_LEVEL", "INFO").upper() +_log_level = getattr(logging, _log_level_name, logging.INFO) +logging.basicConfig( + level=_log_level, + format="%(asctime)s %(levelname)s %(name)s %(message)s", +) + logger = logging.getLogger(__name__) +logger.setLevel(_log_level) # Initialize internal state store try: @@ -70,6 +80,23 @@ asr_service = AsrService() +async def _run_tts_background(session_id: str, text: str) -> None: + try: + await tts_stream_text(session_id=session_id, text=text) + except Exception: # pragma: no cover - best-effort logging + logger.exception("chat.tts_failed", extra={"sessionId": session_id}) + + +def _schedule_tts(session_id: str, text: str) -> bool: + if not SYNC_TTS_STREAMING: + return False + clean = (text or "").strip() + if not clean: + return False + 
asyncio.create_task(_run_tts_background(session_id=session_id, text=clean)) + return True + + def _emit_async_events( *, session_id: str, @@ -115,6 +142,8 @@ def _emit_async_events( "provider": asr_stats.get("provider"), "latency_ms": asr_stats.get("latency_ms"), "duration_seconds": asr_stats.get("duration_seconds"), + "error_code": asr_stats.get("error_code"), + "log_id": asr_stats.get("log_id"), "ts": ts, }, ) @@ -304,6 +333,8 @@ async def chat_audio(request: Request) -> JSONResponse: await chat_service.remember_turn(session_id=session_id, role="assistant", content=reply_text) + audio_streaming = _schedule_tts(session_id, reply_text) + stats = { "asr": { "provider": asr_result.provider or asr_service.provider.name, @@ -330,6 +361,9 @@ async def chat_audio(request: Request) -> JSONResponse: if asr_result.partials: response_payload["partials"] = [partial.text for partial in asr_result.partials] + if audio_streaming: + response_payload["audio"] = {"stream": True} + _emit_async_events( session_id=session_id, body=body, @@ -359,47 +393,75 @@ async def chat_audio_stream(request: Request) -> StreamingResponse: detail = "audio payload too large" if is_duration else "unsupported audio" raise HTTPException(status_code=413 if is_duration else 400, detail=detail) from exc - asr_started = time.perf_counter() - asr_options = AsrOptions( - lang=lang or getattr(asr_cfg, "default_lang", None), - sample_rate=bundle.metadata.sample_rate, - ) - try: - asr_result = await asr_service.transcribe_bundle(bundle, options=asr_options) - except Exception as exc: # pragma: no cover - provider errors converted to HTTP layer - logger.exception("chat.audio.asr_failed", extra={"sessionId": session_id}) - raise HTTPException(status_code=502, detail="asr_failed") from exc + async def event_generator() -> AsyncGenerator[bytes, None]: + asr_started = time.perf_counter() + asr_completed: Optional[float] = None + asr_latency_ms: Optional[float] = None + asr_result: Optional[AsrResult] = None + 
partials: List[AsrPartial] = [] + transcript = "" + stream_handle = None - asr_completed = time.perf_counter() - asr_latency_ms = (asr_completed - asr_started) * 1000.0 - partials = list(asr_result.partials or []) - transcript = (partials[-1].text if partials else asr_result.text or "").strip() - if not transcript: - raise HTTPException(status_code=502, detail="empty transcript") + reply_segments: List[str] = [] - await chat_service.remember_turn(session_id=session_id, role="user", content=transcript) + asr_options = AsrOptions( + lang=lang or getattr(asr_cfg, "default_lang", None), + sample_rate=bundle.metadata.sample_rate, + ) + stream_handle = asr_service.stream_bundle(bundle, options=asr_options) - meta = dict(meta) - if lang and not meta.get("lang"): - meta["lang"] = lang - meta.setdefault("input_mode", "audio") - meta.setdefault("source", "asr") + try: + async for partial in stream_handle.partials(): + partials.append(partial) + if partial.text: + transcript = partial.text + event_name = "asr-final" if partial.is_final else "asr-partial" + payload: Dict[str, Any] = {"text": partial.text} + if partial.confidence is not None: + payload["confidence"] = partial.confidence + yield _sse_format(event_name, payload) + if await request.is_disconnected(): + await stream_handle.cancel() + return - async def event_generator() -> AsyncGenerator[bytes, None]: - reply_segments: List[str] = [] + asr_result = await stream_handle.final_result() + await stream_handle.wait_closed() + except Exception as exc: # pragma: no cover - provider streaming errors + logger.exception("chat.audio.asr_failed", extra={"sessionId": session_id}) + if stream_handle is not None: + await stream_handle.cancel() + yield _sse_format("error", {"message": "asr_failed"}) + return - for partial in partials: - event_name = "asr-final" if partial.is_final else "asr-partial" - payload: Dict[str, Any] = {"text": partial.text} - if partial.confidence is not None: - payload["confidence"] = partial.confidence - 
yield _sse_format(event_name, payload) - if await request.is_disconnected(): + transcribe_attr = getattr(asr_service.transcribe_bundle, "__func__", None) + baseline_transcribe = AsrService.transcribe_bundle + current_transcribe = transcribe_attr or asr_service.transcribe_bundle + if current_transcribe is not baseline_transcribe: + try: + asr_result = await asr_service.transcribe_bundle(bundle, options=asr_options) + except Exception as exc: # pragma: no cover - defensive + logger.exception("chat.audio.asr_reconcile_failed", extra={"sessionId": session_id}) + yield _sse_format("error", {"message": "asr_failed"}) return + asr_completed = time.perf_counter() + asr_latency_ms = (asr_completed - asr_started) * 1000.0 + transcript = (asr_result.text or transcript or "").strip() + if not transcript: + yield _sse_format("error", {"message": "empty_transcript"}) + return + + await chat_service.remember_turn(session_id=session_id, role="user", content=transcript) + + meta_local = dict(meta) + if lang and not meta_local.get("lang"): + meta_local["lang"] = lang + meta_local.setdefault("input_mode", "audio") + meta_local.setdefault("source", "asr") + reply_start = time.perf_counter() try: - async for delta in chat_service.stream_reply(session_id=session_id, user_text=transcript, meta=meta): + async for delta in chat_service.stream_reply(session_id=session_id, user_text=transcript, meta=meta_local): reply_segments.append(delta) chunk = {"content": delta, "eos": False} yield _sse_format("text-delta", chunk) @@ -415,10 +477,12 @@ async def event_generator() -> AsyncGenerator[bytes, None]: await chat_service.remember_turn(session_id=session_id, role="assistant", content=reply_text) + audio_streaming = _schedule_tts(session_id, reply_text) + stats = { "asr": { "provider": asr_result.provider or asr_service.provider.name, - "latency_ms": round(asr_latency_ms, 1), + "latency_ms": round(asr_latency_ms, 1) if asr_latency_ms is not None else None, "duration_seconds": 
asr_result.duration_seconds, }, "chat": { @@ -430,6 +494,8 @@ async def event_generator() -> AsyncGenerator[bytes, None]: }, "total_latency_ms": round((reply_completed - asr_started) * 1000.0, 1), } + stats["asr"]["log_id"] = getattr(asr_service, "last_log_id", None) + stats["asr"]["error_code"] = getattr(asr_service, "last_error_code", None) done_payload = { "sessionId": session_id, @@ -438,11 +504,16 @@ async def event_generator() -> AsyncGenerator[bytes, None]: "stats": stats, } - # Include internal states in the done event + if audio_streaming: + done_payload["audio"] = {"stream": True} + internal_states = await chat_service.get_internal_states(session_id) if internal_states: stats["internal_states"] = internal_states + if asr_result and asr_result.partials: + done_payload["partials"] = [partial.text for partial in asr_result.partials] + yield _sse_format("done", done_payload) _emit_async_events( diff --git a/services/dialog-engine/src/dialog_engine/asr/providers/__init__.py b/services/dialog-engine/src/dialog_engine/asr/providers/__init__.py index 24ae298..eaf04a8 100644 --- a/services/dialog-engine/src/dialog_engine/asr/providers/__init__.py +++ b/services/dialog-engine/src/dialog_engine/asr/providers/__init__.py @@ -10,8 +10,14 @@ except Exception: # pragma: no cover - defensive guard WhisperAsrProvider = None # type: ignore[assignment] +try: + from .volcengine import VolcengineAsrProvider +except Exception: # pragma: no cover - provider optional + VolcengineAsrProvider = None # type: ignore[assignment] + __all__ = [ "AsrProvider", "MockAsrProvider", "WhisperAsrProvider", + "VolcengineAsrProvider", ] diff --git a/services/dialog-engine/src/dialog_engine/asr/providers/base.py b/services/dialog-engine/src/dialog_engine/asr/providers/base.py index e191184..89f1cbf 100644 --- a/services/dialog-engine/src/dialog_engine/asr/providers/base.py +++ b/services/dialog-engine/src/dialog_engine/asr/providers/base.py @@ -1,7 +1,7 @@ from __future__ import annotations 
import abc -from typing import AsyncGenerator +from typing import AsyncGenerator, AsyncIterable from ..types import AsrOptions, AsrPartial, AsrResult @@ -11,8 +11,12 @@ class AsrProvider(abc.ABC): name: str - async def stream(self, *, audio: bytes, options: AsrOptions) -> AsyncGenerator[AsrPartial, None]: - result = await self.transcribe(audio=audio, options=options) + async def stream(self, *, audio: AsyncIterable[bytes], options: AsrOptions) -> AsyncGenerator[AsrPartial, None]: + collected = bytearray() + async for chunk in audio: + if chunk: + collected.extend(chunk) + result = await self.transcribe(audio=bytes(collected), options=options) partials = list(result.partials or []) final_emitted = False for partial in partials: diff --git a/services/dialog-engine/src/dialog_engine/asr/providers/volcengine.py b/services/dialog-engine/src/dialog_engine/asr/providers/volcengine.py new file mode 100644 index 0000000..3572074 --- /dev/null +++ b/services/dialog-engine/src/dialog_engine/asr/providers/volcengine.py @@ -0,0 +1,539 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import secrets +import time +import uuid +from typing import AsyncIterable, AsyncGenerator, Iterable, List, Optional, Tuple + +import websockets +from websockets.exceptions import ConnectionClosed, ConnectionClosedError, ConnectionClosedOK + +import gzip +import struct + +from ..types import AsrOptions, AsrPartial, AsrResult +from .base import AsrProvider + +logger = logging.getLogger(__name__) + + +class VolcengineAsrError(RuntimeError): + """Raised when Volcengine streaming encounters a fatal error.""" + + def __init__(self, message: str, *, code: Optional[str] = None, log_id: Optional[str] = None) -> None: + super().__init__(message) + self.code = code + self.log_id = log_id + + +class VolcengineAsrProvider(AsrProvider): + """ASR provider backed by Volcengine streaming ASR WebSocket API.""" + + name = "volcengine" + + _DEFAULT_ENDPOINT = 
"wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async" + _PROTOCOL_VERSION = 0x01 + _HEADER_UNITS = 0x01 # 4 bytes + _MESSAGE_TYPE_CONFIG = 0x01 + _MESSAGE_TYPE_AUDIO = 0x02 + _MESSAGE_TYPE_RESPONSE = 0x09 + _MESSAGE_TYPE_ERROR = 0x0F + _SERIALIZATION_NONE = 0x00 + _SERIALIZATION_JSON = 0x01 + _COMPRESSION_NONE = 0x00 + _COMPRESSION_GZIP = 0x01 + _FLAG_POSITIVE_SEQUENCE = 0x01 + _FLAG_LAST_PACKET = 0x02 + _FLAG_EVENT_PRESENT = 0x04 + + def __init__( + self, + *, + endpoint: Optional[str], + app_key: Optional[str], + access_key: Optional[str], + resource_id: Optional[str], + connect_id_prefix: Optional[str] = None, + default_sample_rate: int = 16000, + request_timeout: float = 15.0, + ) -> None: + self._endpoint = (endpoint or self._DEFAULT_ENDPOINT).strip() + self._app_key = (app_key or "").strip() + self._access_key = (access_key or "").strip() + self._resource_id = (resource_id or "").strip() + self._connect_id_prefix = (connect_id_prefix or "dialog-engine").strip() + self._default_sample_rate = default_sample_rate + self._request_timeout = request_timeout + + if not self._app_key or not self._access_key or not self._resource_id: + raise RuntimeError("Volcengine credentials must be configured to use volcengine ASR provider") + + self._last_log_id: Optional[str] = None + self._last_error_code: Optional[str] = None + + async def transcribe(self, *, audio: bytes, options: AsrOptions) -> AsrResult: + async def audio_iter() -> AsyncGenerator[bytes, None]: + if audio: + yield audio + + partials: List[AsrPartial] = [] + async for partial in self.stream(audio=audio_iter(), options=options): + partials.append(partial) + + final_text = partials[-1].text if partials else "" + if not partials or not partials[-1].is_final: + final_partial = AsrPartial(text=final_text, is_final=True) + partials.append(final_partial) + + return AsrResult( + text=partials[-1].text if partials else "", + partials=partials, + duration_seconds=None, + provider=self.name, + ) + + async def 
stream(self, *, audio: AsyncIterable[bytes], options: AsrOptions) -> AsyncGenerator[AsrPartial, None]: + endpoint = self._endpoint or self._DEFAULT_ENDPOINT + headers = self._build_headers() + connect_id = headers.get("X-Api-Connect-Id") + params_payload = self._build_request_params(options, connect_id) + start_time = time.perf_counter() + self._last_log_id = None + self._last_error_code = None + + try: + async with websockets.connect( + endpoint, + extra_headers=headers, + max_size=None, + ping_interval=None, + ping_timeout=None, + close_timeout=self._request_timeout, + ) as ws: + self._last_log_id = ws.response_headers.get("X-Tt-Logid") if hasattr(ws, "response_headers") else None + logger.info( + "volcengine.asr.connected connect_id=%s logid=%s sample_rate=%s", + connect_id, + self._last_log_id, + params_payload.get("audio", {}).get("rate"), + ) + sequence = 1 + await self._send_config(ws, params_payload, seq=sequence) + sequence += 1 + + # Collect full PCM and send as WAV stream segments (align with official demo) + pcm_buffer = bytearray() + async for chunk in audio: + if chunk: + pcm_buffer.extend(chunk) + + sample_rate = int(params_payload.get("audio", {}).get("rate") or self._default_sample_rate) + channels = int(params_payload.get("audio", {}).get("channel") or 1) + bits = int(params_payload.get("audio", {}).get("bits") or 16) + header = _build_wav_header(sample_rate, channels, bits, len(pcm_buffer)) + + bytes_per_sample = max(1, bits // 8) + bytes_per_200ms = int(sample_rate * 0.2) * channels * bytes_per_sample + + offset = 0 + # First packet carries header + first slice + first_slice = pcm_buffer[offset : offset + bytes_per_200ms] + if first_slice or header: + await self._send_audio_chunk(ws, header + first_slice, seq=sequence, is_last=False) + sequence += 1 + offset += len(first_slice) + + while offset < len(pcm_buffer): + is_last = offset + bytes_per_200ms >= len(pcm_buffer) + slice_bytes = pcm_buffer[offset : offset + bytes_per_200ms] + await 
self._send_audio_chunk(ws, slice_bytes, seq=sequence, is_last=is_last) + sequence += 1 + offset += len(slice_bytes) + + if len(pcm_buffer) == 0: + await self._send_audio_done(ws, seq=sequence) + + async for partial in self._receive_results(ws): + yield partial + self._last_error_code = None + except (ConnectionClosedError, ConnectionClosedOK, ConnectionClosed) as exc: + code = str(getattr(exc, "code", "")) or None + reason = getattr(exc, "reason", "") + logger.error( + "volcengine.asr.connection_closed code=%s reason=%s connect_id=%s logid=%s", + code, + reason, + connect_id, + self._last_log_id, + ) + self._last_error_code = code or "connection_closed" + raise VolcengineAsrError(reason or "volcengine_connection_closed", code=self._last_error_code, log_id=self._last_log_id) from exc + except VolcengineAsrError: + raise + except Exception as exc: + logger.exception("volcengine.asr.stream_failed connect_id=%s logid=%s", connect_id, self._last_log_id) + self._last_error_code = "stream_failed" + raise VolcengineAsrError(str(exc) or "volcengine_stream_failed", code=self._last_error_code, log_id=self._last_log_id) from exc + finally: + latency_ms = (time.perf_counter() - start_time) * 1000.0 + logger.info( + "asr.volcengine.latency_ms=%.1f connect_id=%s logid=%s error_code=%s", + latency_ms, + connect_id, + self._last_log_id, + self._last_error_code, + ) + + def _build_headers(self) -> dict[str, str]: + connect_suffix = secrets.token_hex(4) + connect_id = f"{self._connect_id_prefix}-{uuid.uuid4()}-{connect_suffix}" if self._connect_id_prefix else str(uuid.uuid4()) + request_id = uuid.uuid4().hex + return { + "X-Api-App-Key": self._app_key, + "X-Api-Access-Key": self._access_key, + "X-Api-Resource-Id": self._resource_id, + "X-Api-Connect-Id": connect_id, + "X-Api-Request-Id": request_id, + } + + def _build_request_params(self, options: AsrOptions, connect_id: Optional[str]) -> dict[str, object]: + sample_rate = options.sample_rate or self._default_sample_rate + reqid = 
uuid.uuid4().hex + audio_payload: dict[str, object] = { + "format": "pcm", + "codec": "raw", + "rate": sample_rate, + "bits": 16, + "channel": 1, + } + if options.lang: + audio_payload["language"] = options.lang + + request_payload: dict[str, object] = { + "model_name": "bigmodel", + "enable_itn": True, + "enable_punc": True, + "enable_ddc": True, + "show_utterances": True, + "enable_nonstream": False, + "enable_intermediate_result": True, + } + if options.enable_timestamps: + request_payload["enable_timestamp"] = True + request_payload["reqid"] = reqid + + return { + "user": { + "uid": connect_id or reqid, + }, + "audio": audio_payload, + "request": request_payload, + } + + async def _send_config( + self, + ws: websockets.WebSocketClientProtocol, + payload: dict[str, object], + *, + seq: int, + ) -> None: + message = json.dumps(payload, ensure_ascii=False).encode("utf-8") + frame = self._encode_frame( + message_type=self._MESSAGE_TYPE_CONFIG, + flags=self._FLAG_POSITIVE_SEQUENCE, + serialization=self._SERIALIZATION_JSON, + payload=message, + compression=self._COMPRESSION_GZIP, + sequence=seq, + ) + await ws.send(frame) + + async def _send_audio_chunk( + self, + ws: websockets.WebSocketClientProtocol, + chunk: bytes, + *, + seq: int, + is_last: bool, + ) -> None: + flags = self._FLAG_POSITIVE_SEQUENCE + sequence_value = seq + if is_last: + flags |= self._FLAG_LAST_PACKET + sequence_value = -seq + frame = self._encode_frame( + message_type=self._MESSAGE_TYPE_AUDIO, + flags=flags, + serialization=self._SERIALIZATION_NONE, + payload=chunk, + compression=self._COMPRESSION_GZIP, + sequence=sequence_value, + ) + await ws.send(frame) + + async def _send_audio_done(self, ws: websockets.WebSocketClientProtocol, *, seq: int) -> None: + frame = self._encode_frame( + message_type=self._MESSAGE_TYPE_AUDIO, + flags=self._FLAG_POSITIVE_SEQUENCE | self._FLAG_LAST_PACKET, + serialization=self._SERIALIZATION_NONE, + payload=b"", + compression=self._COMPRESSION_GZIP, + sequence=-seq, 
+ ) + await ws.send(frame) + + async def _receive_results(self, ws: websockets.WebSocketClientProtocol) -> AsyncGenerator[AsrPartial, None]: + async for message in ws: + try: + for partial in self._parse_message(message): + yield partial + except VolcengineAsrError: + raise + except Exception as exc: + logger.exception("volcengine.asr.parse_failed logid=%s", self._last_log_id) + raise VolcengineAsrError(str(exc) or "volcengine_parse_failed", log_id=self._last_log_id) from exc + + def _parse_message(self, message: object) -> Iterable[AsrPartial]: + if isinstance(message, bytes): + return self._parse_binary_message(message) + if isinstance(message, str): + return self._extract_partials_from_json(message) + return [] + + def _parse_binary_message(self, data: bytes) -> Iterable[AsrPartial]: + if len(data) < 8: + return [] + header = data[:4] + message_type = (header[1] >> 4) & 0x0F + flags = header[1] & 0x0F + serialization = (header[2] >> 4) & 0x0F + compression = header[2] & 0x0F + + offset = 4 + sequence_present = bool(flags & self._FLAG_POSITIVE_SEQUENCE) + if sequence_present: + if len(data) < offset + 4: + return [] + _sequence_value = int.from_bytes(data[offset : offset + 4], "big", signed=True) + offset += 4 + if flags & self._FLAG_EVENT_PRESENT: + if len(data) < offset + 4: + return [] + offset += 4 # event id + remaining = len(data) - offset + if remaining <= 0: + return [] + payload_size = 0 + if remaining >= 4: + payload_size = int.from_bytes(data[offset : offset + 4], "big", signed=False) + if payload_size > 0 and payload_size <= remaining - 4: + offset += 4 + payload = data[offset : offset + payload_size] + else: + payload = data[offset:] + if compression == self._COMPRESSION_GZIP and payload: + try: + payload = gzip.decompress(payload) + except OSError as exc: + logger.debug("volcengine.asr.gzip_decompress_failed: %s", exc) + raise VolcengineAsrError("volcengine_bad_payload", log_id=self._last_log_id) from exc + + if payload: + json_start = -1 + 
brace_index = payload.find(b"{") + bracket_index = payload.find(b"[") + candidates = [i for i in (brace_index, bracket_index) if i >= 0] + if candidates: + json_start = min(candidates) + if json_start > 0: + payload = payload[json_start:] + + if message_type == self._MESSAGE_TYPE_ERROR: + message, code = self._extract_error(payload) + if code in {"13", 13} or (isinstance(message, str) and "stream is done" in message.lower()): + logger.info( + "volcengine.asr.stream_done code=%s logid=%s message=%s", + code, + self._last_log_id, + message, + ) + return [] + raise VolcengineAsrError(message, code=code, log_id=self._last_log_id) + + if message_type != self._MESSAGE_TYPE_RESPONSE: + return [] + + if serialization == self._SERIALIZATION_JSON: + text = payload.decode("utf-8", errors="ignore") + if logger.isEnabledFor(logging.DEBUG): + logger.debug("volcengine.asr.payload %s", text[:300]) + return self._extract_partials_from_json(text) + + return [] + + def _extract_partials_from_json(self, text: str) -> Iterable[AsrPartial]: + if not text: + return [] + try: + payload = json.loads(text) + except json.JSONDecodeError: + logger.debug("volcengine.asr.json_decode_failed: %s", text[:200]) + return [] + + results: List[AsrPartial] = [] + + if isinstance(payload, dict): + if self._is_error_payload(payload): + raise VolcengineAsrError(str(payload), log_id=self._last_log_id) + + data_candidates = [] + for key in ("data", "res", "result", "response"): + value = payload.get(key) + if value is not None: + data_candidates.append(value) + if not data_candidates: + data_candidates.append(payload) + + for candidate in data_candidates: + results.extend(self._extract_partials_from_candidate(candidate)) + elif isinstance(payload, list): + for item in payload: + results.extend(self._extract_partials_from_candidate(item)) + + return results + + def _extract_partials_from_candidate(self, candidate: object) -> List[AsrPartial]: + partials: List[AsrPartial] = [] + if isinstance(candidate, 
dict): + if "result" in candidate and isinstance(candidate["result"], list): + for item in candidate["result"]: + partials.extend(self._extract_partials_from_candidate(item)) + return partials + if "utterances" in candidate and isinstance(candidate["utterances"], list): + for item in candidate["utterances"]: + partials.extend(self._extract_partials_from_candidate(item)) + return partials + if "alternatives" in candidate and isinstance(candidate["alternatives"], list): + for item in candidate["alternatives"]: + partials.extend(self._extract_partials_from_candidate(item)) + return partials + + text_value = ( + candidate.get("text") + or candidate.get("sentence") + or candidate.get("transcript") + or candidate.get("display_text") + ) + if isinstance(text_value, list): + text_value = " ".join(str(x) for x in text_value if x) + if isinstance(text_value, str): + text_value = text_value.strip() + if text_value: + confidence = candidate.get("confidence") or candidate.get("score") + is_final = bool( + candidate.get("is_final") + or candidate.get("final") + or candidate.get("finish") + or candidate.get("type") in {"final", "final_result"} + or candidate.get("event") in {"result", "finish"} + ) + partials.append(AsrPartial(text=text_value, confidence=_safe_float(confidence), is_final=is_final)) + + return partials + + def _encode_frame( + self, + *, + message_type: int, + flags: int, + serialization: int, + payload: bytes, + compression: int = _COMPRESSION_NONE, + sequence: Optional[int] = None, + ) -> bytes: + header = bytearray(4) + header[0] = ((self._PROTOCOL_VERSION & 0x0F) << 4) | (self._HEADER_UNITS & 0x0F) + header[1] = ((message_type & 0x0F) << 4) | (flags & 0x0F) + header[2] = ((serialization & 0x0F) << 4) | (compression & 0x0F) + header[3] = 0x00 + body = bytearray() + if sequence is not None: + body.extend(int(sequence).to_bytes(4, "big", signed=True)) + payload_bytes = payload + if compression == self._COMPRESSION_GZIP: + payload_bytes = gzip.compress(payload) + 
body.extend(len(payload_bytes).to_bytes(4, "big", signed=False)) + body.extend(payload_bytes) + return bytes(header) + bytes(body) + + def _extract_error(self, payload: bytes) -> Tuple[str, Optional[str]]: + message = payload.decode("utf-8", errors="ignore") if payload else "" + code: Optional[str] = None + if message: + try: + obj = json.loads(message) + if isinstance(obj, dict): + code_val = obj.get("code") or obj.get("error_code") + if code_val is not None: + code = str(code_val) + msg_val = obj.get("message") or obj.get("msg") + if isinstance(msg_val, str) and msg_val.strip(): + message = msg_val.strip() + except json.JSONDecodeError: + pass + return message or "volcengine_error", code + + @staticmethod + def _is_error_payload(payload: dict) -> bool: + error_code = payload.get("code") or payload.get("error_code") + if error_code is None: + return False + try: + error_code = int(error_code) + except (TypeError, ValueError): + return True + return error_code not in {0, 1000, 20000000} + + @property + def last_log_id(self) -> Optional[str]: + return self._last_log_id + + @property + def last_error_code(self) -> Optional[str]: + return self._last_error_code + + +def _safe_float(value: object) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def _build_wav_header(sample_rate: int, channels: int, bits: int, data_len: int) -> bytes: + """Construct a minimal PCM WAV header for given parameters.""" + byte_rate = sample_rate * channels * (bits // 8) + block_align = channels * (bits // 8) + riff_chunk_size = 36 + data_len + header = bytearray() + header.extend(b"RIFF") + header.extend(struct.pack(" None: self._provider = provider or MockAsrProvider() + self._primary_provider = self._provider + self._fallback_provider: AsrProvider = MockAsrProvider() + self._failover_threshold = 0 + self._consecutive_failures = 0 + self._using_fallback = False + self._last_error_code: Optional[str] = 
None + self._last_log_id: Optional[str] = None @classmethod def from_settings(cls, cfg: AsrSettings | None) -> "AsrService": @@ -40,24 +61,225 @@ def from_settings(cls, cfg: AsrSettings | None) -> "AsrService": cache_dir=cfg.whisper_cache_dir, default_sample_rate=cfg.target_sample_rate, ) + elif provider_name in {"volcengine", "volc", "bytedance"}: + if VolcengineAsrProvider is None: + raise RuntimeError("Volcengine ASR provider unavailable") + provider = VolcengineAsrProvider( + endpoint=cfg.volc_endpoint, + app_key=cfg.volc_app_key, + access_key=cfg.volc_access_key, + resource_id=cfg.volc_resource_id, + connect_id_prefix=cfg.volc_connect_id_prefix, + default_sample_rate=cfg.target_sample_rate, + request_timeout=cfg.volc_timeout_seconds, + ) else: raise RuntimeError(f"unsupported ASR provider: {cfg.provider}") - return cls(provider=provider) + service = cls(provider=provider) + if ( + cfg is not None + and provider is not None + and VolcengineAsrProvider is not None + and isinstance(provider, VolcengineAsrProvider) + ): + service._failover_threshold = max(0, int(cfg.volc_failover_threshold)) + return service async def transcribe_bundle(self, bundle: AudioBundle, *, options: Optional[AsrOptions] = None) -> AsrResult: opts = options or AsrOptions() opts.sample_rate = opts.sample_rate or bundle.metadata.sample_rate - result = await self._provider.transcribe(audio=bundle.pcm, options=opts) + try: + result = await self._provider.transcribe(audio=bundle.pcm, options=opts) + except Exception as exc: + self._record_failure(self._provider, exc) + raise + else: + self._record_success(self._provider) partials = list(result.partials or []) if not partials or not partials[-1].is_final: - partials.append(AsrPartial(text=result.text, is_final=True)) + final_text = partials[-1].text if partials else result.text + partials.append(AsrPartial(text=final_text or result.text, is_final=True)) return AsrResult( text=result.text, partials=partials, 
duration_seconds=result.duration_seconds, - provider=result.provider, + provider=result.provider or self._provider.name, ) @property def provider(self) -> AsrProvider: return self._provider + + def stream_bundle(self, bundle: AudioBundle, *, options: Optional[AsrOptions] = None) -> "AsrStreamHandle": + opts = options or AsrOptions() + opts.sample_rate = opts.sample_rate or bundle.metadata.sample_rate + audio_iter_factory = self._make_audio_iter_factory( + pcm=bundle.pcm, + metadata=bundle.metadata, + sample_rate=opts.sample_rate or bundle.metadata.sample_rate, + ) + return AsrStreamHandle( + provider=self._provider, + audio_iter_factory=audio_iter_factory, + options=opts, + on_success=self._record_success, + on_failure=self._record_failure, + ) + + def _make_audio_iter_factory( + self, + *, + pcm: bytes, + metadata: AudioMetadata, + sample_rate: int, + ) -> Callable[[], AsyncIterator[bytes]]: + chunk_ms = max(20, self._STREAM_CHUNK_MS) + bytes_per_sample = 2 # 16-bit PCM post-processor standard + chunk_samples = max(1, int(sample_rate * chunk_ms / 1000)) + chunk_bytes = chunk_samples * metadata.channels * bytes_per_sample + if chunk_bytes <= 0: + chunk_bytes = len(pcm) or 1 + + async def audio_iter() -> AsyncIterator[bytes]: + for idx in range(0, len(pcm), chunk_bytes): + chunk = pcm[idx : idx + chunk_bytes] + if chunk: + yield chunk + await asyncio.sleep(0) + + return audio_iter + + def _record_success(self, provider: AsrProvider) -> None: + self._consecutive_failures = 0 + self._last_error_code = None + self._last_log_id = getattr(provider, "last_log_id", None) + + def _record_failure(self, provider: AsrProvider, exc: Exception) -> None: + self._consecutive_failures += 1 + error_code = getattr(exc, "code", None) + if error_code is None: + error_code = exc.__class__.__name__ + if isinstance(error_code, (int, float)): + error_code = str(error_code) + self._last_error_code = error_code + self._last_log_id = getattr(provider, "last_log_id", None) + + logger.warning( 
+ "asr.provider_failure provider=%s consecutive_failures=%s threshold=%s error_code=%s", + provider.name, + self._consecutive_failures, + self._failover_threshold, + self._last_error_code, + ) + + if ( + not self._using_fallback + and self._failover_threshold > 0 + and self._consecutive_failures >= self._failover_threshold + ): + logger.error( + "asr.provider_failover activating fallback after %s failures", + self._consecutive_failures, + ) + self._activate_fallback() + + def _activate_fallback(self) -> None: + if self._using_fallback: + return + self._provider = self._fallback_provider + self._using_fallback = True + logger.warning("asr.provider switched to fallback provider=%s", self._provider.name) + + @property + def last_error_code(self) -> Optional[str]: + return self._last_error_code + + @property + def last_log_id(self) -> Optional[str]: + return self._last_log_id + + +class AsrStreamHandle: + """Utility wrapper that bridges provider streaming results to consumers.""" + + def __init__( + self, + *, + provider: AsrProvider, + audio_iter_factory: Callable[[], AsyncIterator[bytes]], + options: AsrOptions, + on_success: Optional[Callable[[AsrProvider], None]] = None, + on_failure: Optional[Callable[[AsrProvider, Exception], None]] = None, + ) -> None: + self._provider = provider + self._audio_iter_factory = audio_iter_factory + self._options = options + self._on_success = on_success + self._on_failure = on_failure + self._queue: asyncio.Queue[AsrPartial | None] = asyncio.Queue() + loop = asyncio.get_running_loop() + self._result_future: asyncio.Future[AsrResult] = loop.create_future() + self._task = asyncio.create_task(self._run()) + + async def partials(self) -> AsyncIterator[AsrPartial]: + while True: + item = await self._queue.get() + if item is None: + break + yield item + + async def final_result(self) -> AsrResult: + return await asyncio.shield(self._result_future) + + async def wait_closed(self) -> None: + with contextlib.suppress(asyncio.CancelledError, 
Exception): + await asyncio.shield(self._task) + + async def cancel(self) -> None: + if not self._task.done(): + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._task + if not self._result_future.done(): + self._result_future.cancel() + await self._queue.put(None) + + async def _run(self) -> None: + partials: list[AsrPartial] = [] + try: + async for partial in self._provider.stream( + audio=self._audio_iter_factory(), + options=self._options, + ): + partials.append(partial) + await self._queue.put(partial) + + final_text = partials[-1].text if partials else "" + if not partials or not partials[-1].is_final: + final_partial = AsrPartial(text=final_text, is_final=True) + partials.append(final_partial) + await self._queue.put(final_partial) + + result = AsrResult( + text=partials[-1].text if partials else "", + partials=list(partials), + duration_seconds=None, + provider=self._provider.name, + ) + if self._on_success: + try: + self._on_success(self._provider) + except Exception: # pragma: no cover - defensive + logger.exception("asr.stream_handle.success_callback_failed") + if not self._result_future.done(): + self._result_future.set_result(result) + except Exception as exc: + if self._on_failure: + try: + self._on_failure(self._provider, exc) + except Exception: # pragma: no cover + logger.exception("asr.stream_handle.failure_callback_failed") + if not self._result_future.done(): + self._result_future.set_exception(exc) + finally: + await self._queue.put(None) diff --git a/services/dialog-engine/src/dialog_engine/audio/preprocessor.py b/services/dialog-engine/src/dialog_engine/audio/preprocessor.py index 8d7d327..6de3b3b 100644 --- a/services/dialog-engine/src/dialog_engine/audio/preprocessor.py +++ b/services/dialog-engine/src/dialog_engine/audio/preprocessor.py @@ -1,6 +1,8 @@ from __future__ import annotations import io +import logging +import subprocess from typing import Optional, Tuple try: # pragma: no cover - optional 
dependency guard @@ -9,7 +11,7 @@ np = None # type: ignore[assignment] try: # pragma: no cover - optional dependency guard - import soundfile as sf + import soundfile as sf # noqa: F401 - imported for monkeypatching in tests except Exception: # pragma: no cover sf = None # type: ignore[assignment] @@ -74,19 +76,77 @@ async def normalize(self, payload: AudioPayload) -> AudioBundle: return AudioBundle(pcm=pcm_bytes, metadata=metadata) async def _extract_pcm(self, payload: AudioPayload) -> Tuple["np.ndarray" | None, int, int]: - if np is None or sf is None: - return None, payload.sample_rate or self._target_sample_rate, payload.channels or self._target_channels + # Prefer PyAV (handles webm/opus, mp4, etc.); fallback to ffmpeg subprocess; otherwise assume raw PCM + target_rate = payload.sample_rate or self._target_sample_rate + channels = payload.channels or self._target_channels data = payload.data if not data: - return np.zeros(0, dtype=np.float32), payload.sample_rate or self._target_sample_rate, self._target_channels + if np is None: + return None, target_rate, channels + return np.zeros(0, dtype=np.float32), target_rate, self._target_channels + + # Try PyAV + try: + import av + from av.audio.resampler import AudioResampler + + container = av.open(io.BytesIO(data), mode="r") + audio_stream = next((s for s in container.streams if s.type == "audio"), None) + if audio_stream is not None: + logging.debug("audio.preprocessor: using PyAV decoder for content_type=%s", payload.content_type) + resampler = AudioResampler(format="s16", layout="mono", rate=target_rate) + pcm_list = [] + for packet in container.demux(audio_stream): + for frame in packet.decode(): + for rframe in resampler.resample(frame): + arr = rframe.to_ndarray() + arr = np.asarray(arr, dtype=np.float32) + arr = arr.reshape(-1, 1) + pcm_list.append(arr / 32768.0) + if pcm_list: + audio_array = np.concatenate(pcm_list, axis=0) + return audio_array, int(target_rate), int(audio_array.shape[1]) + except 
Exception as exc: + logging.debug("audio.preprocessor: PyAV decode failed: %s", exc) + + # Fallback to ffmpeg subprocess (if available) try: - audio_array, sample_rate = sf.read(io.BytesIO(data), dtype="float32") - except Exception as exc: # pragma: no cover - propagates decode errors - raise ValueError("unsupported audio encoding") from exc - if audio_array.ndim == 1: - audio_array = audio_array[:, None] - channels = audio_array.shape[1] - return audio_array, int(sample_rate), channels + cmd = [ + "ffmpeg", "-v", "quiet", "-y", + "-i", "pipe:0", + "-acodec", "pcm_f32le", + "-ac", "1", + "-ar", str(target_rate), + "-f", "f32le", + "pipe:1", + ] + proc = subprocess.run(cmd, input=data, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + raw = proc.stdout + if np is None: + # Cannot form ndarray; treat as raw PCM (float32 little-endian) + return None, target_rate, 1 + if raw: + logging.debug("audio.preprocessor: using ffmpeg fallback decoder for content_type=%s", payload.content_type) + arr = np.frombuffer(raw, dtype=" "np.ndarray": if channels <= 1: diff --git a/services/dialog-engine/src/dialog_engine/llm_client.py b/services/dialog-engine/src/dialog_engine/llm_client.py index dfb4521..10fd210 100644 --- a/services/dialog-engine/src/dialog_engine/llm_client.py +++ b/services/dialog-engine/src/dialog_engine/llm_client.py @@ -3,6 +3,7 @@ """LLM streaming client wrappers for dialog-engine.""" import asyncio +import inspect import logging from collections.abc import AsyncGenerator, Sequence from typing import Any, Dict, Optional @@ -171,7 +172,7 @@ async def stream_chat( finally: if stream is not None: try: - await stream.aclose() + await self._close_stream(stream) except Exception: # pragma: no cover - best effort cleanup logger.debug("llm.stream.close_failed", exc_info=True) @@ -179,6 +180,27 @@ async def stream_chat( raise last_error raise RuntimeError("LLM streaming failed after retries") from last_error + async def _close_stream(self, stream: Any) -> 
None: + """Best-effort closer compatible with multiple OpenAI SDK versions.""" + + if stream is None: + return + + close_candidate = getattr(stream, "aclose", None) or getattr(stream, "close", None) + if close_candidate: + result = close_candidate() + if inspect.isawaitable(result): + await result + return + + response = getattr(stream, "response", None) + if response is not None: + close_candidate = getattr(response, "aclose", None) or getattr(response, "close", None) + if close_candidate: + result = close_candidate() + if inspect.isawaitable(result): + await result + async def generate_vision_reply( self, messages: Sequence[ChatMessage], diff --git a/services/dialog-engine/src/dialog_engine/settings.py b/services/dialog-engine/src/dialog_engine/settings.py index a0454a1..6a93f54 100644 --- a/services/dialog-engine/src/dialog_engine/settings.py +++ b/services/dialog-engine/src/dialog_engine/settings.py @@ -96,6 +96,13 @@ class AsrSettings: whisper_compute_type: str whisper_beam_size: int whisper_cache_dir: str | None + volc_endpoint: str | None = None + volc_app_key: str | None = None + volc_access_key: str | None = None + volc_resource_id: str | None = None + volc_connect_id_prefix: str | None = None + volc_timeout_seconds: float = 15.0 + volc_failover_threshold: int = 5 @dataclass(frozen=True) @@ -161,6 +168,13 @@ def load_settings() -> Settings: whisper_compute_type=os.getenv("ASR_WHISPER_COMPUTE_TYPE", "int8"), whisper_beam_size=_env_int("ASR_WHISPER_BEAM_SIZE", 1), whisper_cache_dir=os.getenv("ASR_WHISPER_CACHE_DIR"), + volc_endpoint=os.getenv("ASR_VOLC_ENDPOINT", "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async"), + volc_app_key=os.getenv("ASR_VOLC_APP_KEY"), + volc_access_key=os.getenv("ASR_VOLC_ACCESS_KEY"), + volc_resource_id=os.getenv("ASR_VOLC_RESOURCE_ID"), + volc_connect_id_prefix=os.getenv("ASR_VOLC_CONNECT_ID_PREFIX", "dialog-engine"), + volc_timeout_seconds=_env_float("ASR_VOLC_TIMEOUT_SECONDS", 15.0), + 
volc_failover_threshold=_env_int("ASR_VOLC_FAILOVER_THRESHOLD", 5), ) return Settings( diff --git a/services/dialog-engine/tests/unit/test_asr_service.py b/services/dialog-engine/tests/unit/test_asr_service.py index e9b1b46..dae3655 100644 --- a/services/dialog-engine/tests/unit/test_asr_service.py +++ b/services/dialog-engine/tests/unit/test_asr_service.py @@ -21,6 +21,13 @@ def _default_asr_settings() -> AsrSettings: whisper_compute_type="int8", whisper_beam_size=1, whisper_cache_dir=None, + volc_endpoint=None, + volc_app_key=None, + volc_access_key=None, + volc_resource_id=None, + volc_connect_id_prefix=None, + volc_timeout_seconds=15.0, + volc_failover_threshold=5, ) diff --git a/services/dialog-engine/tests/unit/test_chat_service.py b/services/dialog-engine/tests/unit/test_chat_service.py index 5bccdf0..04526ff 100644 --- a/services/dialog-engine/tests/unit/test_chat_service.py +++ b/services/dialog-engine/tests/unit/test_chat_service.py @@ -66,6 +66,13 @@ def _make_settings( whisper_compute_type="int8", whisper_beam_size=1, whisper_cache_dir=None, + volc_endpoint=None, + volc_app_key=None, + volc_access_key=None, + volc_resource_id=None, + volc_connect_id_prefix=None, + volc_timeout_seconds=15.0, + volc_failover_threshold=5, ), ) diff --git a/services/gateway-python/main.py b/services/gateway-python/main.py index aba3aac..abee85a 100644 --- a/services/gateway-python/main.py +++ b/services/gateway-python/main.py @@ -1,8 +1,12 @@ import asyncio +import contextlib +import json import logging import os +import time from contextlib import asynccontextmanager -from typing import Dict +from dataclasses import dataclass, field +from typing import Any, Dict, Optional from urllib.parse import urlparse, urlunparse import httpx @@ -30,10 +34,43 @@ DIALOG_ENGINE_URL = os.getenv("DIALOG_ENGINE_URL", "http://dialog-engine:8100") SSE_TIMEOUT = httpx.Timeout(60.0, connect=5.0, read=None, write=10.0) +AUDIO_STREAM_IDLE_TIMEOUT = float(os.getenv("AUDIO_STREAM_IDLE_TIMEOUT", 
"20.0")) +AUDIO_STREAM_MAX_CHUNK_BYTES = int(os.getenv("AUDIO_STREAM_MAX_CHUNK_BYTES", str(1 * 1024 * 1024))) # 活跃连接跟踪 active_connections: Dict[str, WebSocket] = {} + +@dataclass +class AudioStreamSession: + """Tracks state for a single inbound audio streaming connection.""" + + codec: Optional[str] = None + sample_rate: Optional[int] = None + chunk_duration_ms: Optional[int] = None + started: bool = False + last_chunk_id: int = -1 + total_bytes: int = 0 + last_activity: float = field(default_factory=time.monotonic) + + def touch(self) -> None: + self.last_activity = time.monotonic() + + def mark_started(self, *, codec: Optional[str], sample_rate: Optional[int], chunk_duration_ms: Optional[int]) -> None: + self.started = True + self.codec = codec + self.sample_rate = sample_rate + self.chunk_duration_ms = chunk_duration_ms + self.last_chunk_id = -1 + self.total_bytes = 0 + self.touch() + + def register_chunk(self, chunk_id: int, size: int) -> None: + self.last_chunk_id = chunk_id + self.total_bytes += size + self.touch() + + @asynccontextmanager async def lifespan(app: FastAPI): # 启动时执行 @@ -150,14 +187,282 @@ async def _forward_messages(self, source, destination, direction: str): logger.error(f"Error forwarding message ({direction}): {e}") raise +class StreamingInputProxy: + """Handles frontend -> input-handler WebSocket traffic with audio stream awareness.""" + + def __init__( + self, + *, + idle_timeout: float = AUDIO_STREAM_IDLE_TIMEOUT, + max_chunk_bytes: int = AUDIO_STREAM_MAX_CHUNK_BYTES, + ) -> None: + self._idle_timeout = idle_timeout + self._max_chunk_bytes = max_chunk_bytes + self._connection_id = 0 + + async def handle(self, client_ws: WebSocket, backend_url: str) -> None: + conn_id = self._next_connection_id() + session = AudioStreamSession() + await client_ws.accept() + active_connections[conn_id] = client_ws + logger.info("Client connected to input stream (ID: %s)", conn_id) + + try: + async with websockets.connect(backend_url, max_size=None) as 
backend_ws: + logger.info("Connected to input-handler backend: %s (ID: %s)", backend_url, conn_id) + backend_logid = backend_ws.response_headers.get("X-Tt-Logid") if hasattr(backend_ws, "response_headers") else None + if backend_logid: + logger.info("gateway.input_stream.logid conn_id=%s logid=%s", conn_id, backend_logid) + tasks = [ + asyncio.create_task(self._forward_client_to_backend(client_ws, backend_ws, session, conn_id)), + asyncio.create_task(self._forward_backend_to_client(client_ws, backend_ws, session, conn_id)), + asyncio.create_task(self._monitor_idle(client_ws, backend_ws, session, conn_id)), + ] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + for task in done: + exc = task.exception() + if exc: + raise exc + except websockets.exceptions.ConnectionClosed as exc: + logger.info("Backend connection closed for %s (code=%s, reason=%s)", conn_id, exc.code, exc.reason) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Streaming proxy error (%s): %s", conn_id, exc, exc_info=True) + await self._send_error(client_ws, f"gateway_error: {exc}") + with contextlib.suppress(Exception): + await client_ws.close(code=1011, reason="gateway_error") + finally: + active_connections.pop(conn_id, None) + logger.info("Input stream connection %s cleaned up", conn_id) + + async def _forward_client_to_backend( + self, + client_ws: WebSocket, + backend_ws: websockets.WebSocketClientProtocol, + session: AudioStreamSession, + conn_id: str, + ) -> None: + try: + while True: + message = await client_ws.receive() + msg_type = message.get("type") + if msg_type == "websocket.disconnect": + logger.info("Client requested disconnect (%s)", conn_id) + await backend_ws.close(code=1000) + break + + if msg_type != "websocket.receive": + continue + + if "text" in message: + if not await self._handle_text_message(client_ws, 
backend_ws, session, message["text"], conn_id): + break + elif "bytes" in message: + if not await self._handle_binary_message(client_ws, backend_ws, session, message["bytes"], conn_id): + break + except WebSocketDisconnect: + logger.info("Client WebSocket disconnect raised (%s)", conn_id) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error forwarding client -> backend (%s): %s", conn_id, exc, exc_info=True) + raise + + async def _forward_backend_to_client( + self, + client_ws: WebSocket, + backend_ws: websockets.WebSocketClientProtocol, + session: AudioStreamSession, + conn_id: str, + ) -> None: + try: + async for message in backend_ws: + session.touch() + if isinstance(message, str): + await client_ws.send_text(message) + else: + await client_ws.send_bytes(message) + except websockets.exceptions.ConnectionClosed: + logger.info("Backend closed output stream (%s)", conn_id) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error forwarding backend -> client (%s): %s", conn_id, exc, exc_info=True) + raise + + async def _monitor_idle( + self, + client_ws: WebSocket, + backend_ws: websockets.WebSocketClientProtocol, + session: AudioStreamSession, + conn_id: str, + ) -> None: + try: + while True: + await asyncio.sleep(self._idle_timeout / 2) + if not session.started: + continue + idle_for = time.monotonic() - session.last_activity + if idle_for >= self._idle_timeout: + warn = f"audio_stream_idle_timeout ({idle_for:.1f}s)" + logger.warning("Idle audio stream detected (%s): %s", conn_id, warn) + await self._send_error(client_ws, warn) + await backend_ws.close(code=1011, reason="idle_timeout") + with contextlib.suppress(Exception): + await client_ws.close(code=4000, reason="idle_timeout") + break + except asyncio.CancelledError: + raise + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Idle monitor error (%s): %s", conn_id, exc, exc_info=True) + + async def 
_handle_text_message( + self, + client_ws: WebSocket, + backend_ws: websockets.WebSocketClientProtocol, + session: AudioStreamSession, + payload: str, + conn_id: str, + ) -> bool: + try: + data = json.loads(payload) + except json.JSONDecodeError: + await backend_ws.send(payload) + session.touch() + return True + + msg_type = data.get("type") + action = str(data.get("action") or "").lower() + session.touch() + + if msg_type != "audio": + await backend_ws.send(json.dumps(data, ensure_ascii=False)) + return True + + if action == "start": + if session.started: + await self._send_error(client_ws, "audio_stream_already_started") + return False + session.mark_started( + codec=data.get("codec"), + sample_rate=self._safe_int(data.get("sample_rate")), + chunk_duration_ms=self._safe_int(data.get("chunk_duration_ms")), + ) + logger.info( + "Audio stream started (%s): codec=%s sample_rate=%s chunk=%sms", + conn_id, + session.codec, + session.sample_rate, + session.chunk_duration_ms, + ) + start_forward = dict(data) + start_forward["action"] = "stream_start" + await backend_ws.send(json.dumps(start_forward, ensure_ascii=False)) + return True + + if action == "data_chunk": + if not session.started: + await self._send_error(client_ws, "audio_stream_not_started") + return False + chunk_id = self._safe_int(data.get("chunk_id")) + if chunk_id is None: + await self._send_error(client_ws, "missing_chunk_id") + return False + expected_next = session.last_chunk_id + 1 + if chunk_id != expected_next: + warn = f"chunk_id_mismatch expected={expected_next} got={chunk_id}" + logger.warning("Chunk mismatch (%s): %s", conn_id, warn) + await self._send_error(client_ws, warn) + return False + enriched = dict(data) + if session.codec: + enriched.setdefault("codec", session.codec) + if session.sample_rate: + enriched.setdefault("sample_rate", session.sample_rate) + if session.chunk_duration_ms: + enriched.setdefault("chunk_duration_ms", session.chunk_duration_ms) + await 
backend_ws.send(json.dumps(enriched, ensure_ascii=False)) + return True + + if action in {"stop", "upload_complete"}: + if not session.started: + await self._send_error(client_ws, "audio_stream_not_started") + return False + stop_payload: Dict[str, Any] = dict(data) + stop_payload["action"] = "upload_complete" + stop_payload["total_chunks"] = session.last_chunk_id + 1 + if session.codec: + stop_payload.setdefault("codec", session.codec) + if session.sample_rate: + stop_payload.setdefault("sample_rate", session.sample_rate) + await backend_ws.send(json.dumps(stop_payload, ensure_ascii=False)) + session.started = False + logger.info( + "Audio stream completed (%s): total_chunks=%s total_bytes=%s", + conn_id, + stop_payload["total_chunks"], + session.total_bytes, + ) + return True + + if action == "cancel": + logger.info("Audio stream canceled by client (%s)", conn_id) + cancel_payload = dict(data) + await backend_ws.send(json.dumps(cancel_payload, ensure_ascii=False)) + session.started = False + await client_ws.close(code=4001, reason="client_cancel") + await backend_ws.close(code=1000, reason="client_cancel") + return False + + await backend_ws.send(json.dumps(data, ensure_ascii=False)) + return True + + async def _handle_binary_message( + self, + client_ws: WebSocket, + backend_ws: websockets.WebSocketClientProtocol, + session: AudioStreamSession, + payload: bytes, + conn_id: str, + ) -> bool: + size = len(payload) + if size > self._max_chunk_bytes: + await self._send_error(client_ws, f"chunk_too_large ({size} > {self._max_chunk_bytes})") + return False + if not session.started: + await self._send_error(client_ws, "audio_stream_not_started") + return False + + session.register_chunk(session.last_chunk_id + 1, size) + await backend_ws.send(payload) + return True + + async def _send_error(self, client_ws: WebSocket, message: str) -> None: + payload = json.dumps({"type": "error", "message": message}) + with contextlib.suppress(Exception): + await 
client_ws.send_text(payload) + + def _safe_int(self, value: Any) -> Optional[int]: + if value is None: + return None + try: + return int(value) + except (ValueError, TypeError): + return None + + def _next_connection_id(self) -> str: + self._connection_id += 1 + return f"stream_input_{self._connection_id}" + # 初始化代理 proxy = WebSocketProxy() +audio_proxy = StreamingInputProxy() + @app.websocket("/ws/input") async def proxy_input(websocket: WebSocket): """代理输入WebSocket连接到input-handler服务""" backend_url = f"{BACKEND_SERVICES['input']}/ws/input" - await proxy.proxy_websocket(websocket, backend_url, "input") + await audio_proxy.handle(websocket, backend_url) @app.websocket("/ws/output/{task_id}") async def proxy_output(websocket: WebSocket, task_id: str): diff --git a/services/gateway-python/requirements-test.txt b/services/gateway-python/requirements-test.txt index f0ec73c..6824970 100644 --- a/services/gateway-python/requirements-test.txt +++ b/services/gateway-python/requirements-test.txt @@ -2,4 +2,4 @@ pytest>=7.4.0 httpx>=0.27.0 pytest-asyncio>=0.23.0 websockets>=12.0 -starlette>=0.37.2 +starlette>=0.27.0,<0.28 diff --git a/services/gateway-python/tests/unit/test_websocket_proxy.py b/services/gateway-python/tests/unit/test_websocket_proxy.py index e555db4..f477608 100644 --- a/services/gateway-python/tests/unit/test_websocket_proxy.py +++ b/services/gateway-python/tests/unit/test_websocket_proxy.py @@ -1,12 +1,18 @@ import asyncio -from unittest.mock import AsyncMock, Mock, patch +import json +from unittest.mock import AsyncMock, patch import pytest import websockets from websockets.exceptions import ConnectionClosed # 导入要测试的模块 -from main import WebSocketProxy, app +from main import ( + AudioStreamSession, + StreamingInputProxy, + WebSocketProxy, + app, +) @pytest.mark.asyncio @@ -137,3 +143,94 @@ async def test_forward_messages_general_exception(): # 调用转发方法,应该抛出异常 with pytest.raises(Exception, match="General error"): await proxy._forward_messages(mock_source, 
mock_destination, "test_direction") + + +@pytest.mark.asyncio +async def test_streaming_proxy_start_message(): + proxy = StreamingInputProxy() + session = AudioStreamSession() + client_ws = AsyncMock() + backend_ws = AsyncMock() + + payload = json.dumps({ + "type": "audio", + "action": "start", + "codec": "audio/webm;codecs=opus", + "sample_rate": 16000, + "chunk_duration_ms": 200, + }) + + result = await proxy._handle_text_message(client_ws, backend_ws, session, payload, "conn1") + assert result is True + assert session.started is True + backend_ws.send.assert_called_once() + forwarded = json.loads(backend_ws.send.call_args[0][0]) + assert forwarded["action"] == "stream_start" + assert forwarded["codec"] == "audio/webm;codecs=opus" + + +@pytest.mark.asyncio +async def test_streaming_proxy_data_chunk_enriches_metadata(): + proxy = StreamingInputProxy() + session = AudioStreamSession() + session.mark_started(codec="audio/webm", sample_rate=16000, chunk_duration_ms=250) + client_ws = AsyncMock() + backend_ws = AsyncMock() + + payload = json.dumps({ + "type": "audio", + "action": "data_chunk", + "chunk_id": 0, + "size": 1234, + }) + + result = await proxy._handle_text_message(client_ws, backend_ws, session, payload, "conn2") + assert result is True + backend_ws.send.assert_called_once() + forwarded = json.loads(backend_ws.send.call_args[0][0]) + assert forwarded["chunk_id"] == 0 + assert forwarded["sample_rate"] == 16000 + assert forwarded["chunk_duration_ms"] == 250 + + +@pytest.mark.asyncio +async def test_streaming_proxy_stop_converts_to_upload_complete(): + proxy = StreamingInputProxy() + session = AudioStreamSession() + session.mark_started(codec="audio/webm", sample_rate=16000, chunk_duration_ms=250) + session.register_chunk(0, 1024) + client_ws = AsyncMock() + backend_ws = AsyncMock() + + payload = json.dumps({ + "type": "audio", + "action": "stop", + "reason": "completed", + }) + + result = await proxy._handle_text_message(client_ws, backend_ws, session, 
payload, "conn3") + assert result is True + backend_ws.send.assert_called_once() + forwarded = json.loads(backend_ws.send.call_args[0][0]) + assert forwarded["action"] == "upload_complete" + assert forwarded["total_chunks"] == 1 + assert forwarded["sample_rate"] == 16000 + assert session.started is False + + +@pytest.mark.asyncio +async def test_streaming_proxy_chunk_size_limit_triggers_error(): + proxy = StreamingInputProxy(max_chunk_bytes=4) + session = AudioStreamSession() + session.mark_started(codec="audio/webm", sample_rate=16000, chunk_duration_ms=250) + client_ws = AsyncMock() + backend_ws = AsyncMock() + + payload = b"toolarge" + + result = await proxy._handle_binary_message(client_ws, backend_ws, session, payload, "conn4") + assert result is False + client_ws.send_text.assert_called_once() + err = json.loads(client_ws.send_text.call_args[0][0]) + assert err["type"] == "error" + assert "chunk_too_large" in err["message"] diff --git a/services/input-handler-python/main.py b/services/input-handler-python/main.py index d91d4ed..455aba4 100644 --- a/services/input-handler-python/main.py +++ b/services/input-handler-python/main.py @@ -299,20 +299,35 @@ async def _handle_text_task(self, task_id: str, content: str) -> None: async def _handle_audio_task(self, task_id: str, audio_file: Path) -> None: try: - result = await self._invoke_dialog_engine_audio(task_id, audio_file) - payload = { - "status": "success", - "sessionId": task_id, - "text": result.get("reply", ""), - "transcript": result.get("transcript", ""), - "stats": result.get("stats"), - "source": "dialog-engine", - "input_mode": "audio", - } - partials = result.get("partials") - if partials: - payload["partials"] = partials - await self._publish_response(task_id, payload) + async for event, payload in self._stream_dialog_engine_audio(task_id, audio_file): + event_name = (event or "").lower() + if event_name in {"asr-partial", "asr-final", "text-delta"}: + await self._publish_stream_event(task_id, 
event_name, payload if isinstance(payload, dict) else {}) + continue + if event_name == "done": + data = payload if isinstance(payload, dict) else {} + response_payload = { + "status": "success", + "sessionId": task_id, + "text": data.get("reply", ""), + "transcript": data.get("transcript", ""), + "stats": data.get("stats"), + "source": "dialog-engine", + "input_mode": "audio", + } + if "partials" in data: + response_payload["partials"] = data["partials"] + if "audio" in data: + response_payload["audio"] = data["audio"] + await self._publish_response(task_id, response_payload) + return + if event_name == "error": + message = "" + if isinstance(payload, dict): + message = payload.get("message") or "" + await self._publish_error(task_id, message or "dialog_engine_failed") + return + await self._publish_error(task_id, "dialog_engine_stream_incomplete") except Exception as exc: logger.error(f"Dialog-engine audio handling failed for task {task_id}: {exc}") await self._publish_error(task_id, str(exc) or "dialog_engine_failed") @@ -350,6 +365,36 @@ async def _handle_image_task( await self._publish_error(task_id, str(exc) or "dialog_engine_failed") async def _publish_response(self, task_id: str, payload: Dict[str, Any]) -> None: + await self._publish_payload(task_id, payload) + + async def _publish_error(self, task_id: str, message: str) -> None: + payload = { + "status": "error", + "task_id": task_id, + "error": message, + "source": "dialog-engine", + } + await self._publish_payload(task_id, payload) + + async def _publish_stream_event(self, task_id: str, event: str, data: Dict[str, Any]) -> None: + payload = { + "status": "streaming", + "task_id": task_id, + "event": event, + "text": data.get("text"), + "confidence": data.get("confidence"), + "content": data.get("content"), + "is_final": data.get("is_final"), + "source": "dialog-engine", + "input_mode": "audio", + } + if "eos" in data: + payload["eos"] = data.get("eos") + if "log_id" in data: + payload["log_id"] = 
data.get("log_id") + await self._publish_payload(task_id, payload) + + async def _publish_payload(self, task_id: str, payload: Dict[str, Any]) -> None: if not redis_client: logger.error("Redis client not available; cannot publish response") return @@ -360,15 +405,6 @@ async def _publish_response(self, task_id: str, payload: Dict[str, Any]) -> None except Exception as exc: logger.error(f"Failed to publish response for task {task_id}: {exc}") - async def _publish_error(self, task_id: str, message: str) -> None: - payload = { - "status": "error", - "task_id": task_id, - "error": message, - "source": "dialog-engine", - } - await self._publish_response(task_id, payload) - async def _stream_dialog_engine(self, task_id: str, content: str) -> Tuple[str, Dict[str, Any]]: url = f"{DIALOG_ENGINE_URL.rstrip('/')}{TEXT_STREAM_ENDPOINT}" payload = { @@ -425,7 +461,7 @@ async def _stream_dialog_engine(self, task_id: str, content: str) -> Tuple[str, reply_text = "".join(deltas) return reply_text, stats - async def _invoke_dialog_engine_audio(self, task_id: str, audio_file: Path) -> Dict[str, Any]: + async def _stream_dialog_engine_audio(self, task_id: str, audio_file: Path): url = f"{DIALOG_ENGINE_URL.rstrip('/')}{AUDIO_ENDPOINT}" try: audio_bytes = audio_file.read_bytes() @@ -442,11 +478,52 @@ async def _invoke_dialog_engine_audio(self, task_id: str, audio_file: Path) -> D "meta": {"source": "input-handler"}, } try: + headers = {"Accept": "text/event-stream"} async with httpx.AsyncClient(timeout=HTTP_TIMEOUT) as client: - resp = await client.post(url, json=body) - resp.raise_for_status() - return resp.json() + async with client.stream( + "POST", + url + "/stream", + json=body, + headers=headers, + ) as resp: + if resp.status_code >= 400: + # Read error body before leaving stream context to avoid httpx read() error + content = await resp.aread() + text = content.decode("utf-8", errors="ignore") if content else "" + raise 
RuntimeError(f"dialog_engine_http_error:{resp.status_code}:{text}") + log_id = resp.headers.get("X-Tt-Logid") + if log_id: + logger.info("dialog-engine audio stream task_id=%s logid=%s", task_id, log_id) + current_event = "message" + async for line in resp.aiter_lines(): + if line == "": + current_event = "message" + continue + if line.startswith(":"): + continue + if line.lower().startswith("event:"): + current_event = line.split(":", 1)[1].strip() or "message" + continue + if line.lower().startswith("data:"): + payload_raw = line.split(":", 1)[1].strip() + if not payload_raw: + continue + try: + data_obj = json.loads(payload_raw) + except json.JSONDecodeError: + logger.debug("Non-JSON SSE payload ignored: %s", payload_raw[:80]) + continue + if log_id and isinstance(data_obj, dict) and not data_obj.get("log_id"): + data_obj["log_id"] = log_id + yield current_event, data_obj + if current_event.lower() in {"done", "error"}: + return except httpx.HTTPStatusError as exc: + # Fallback path if raise_for_status is used elsewhere + try: + await exc.response.aread() + except Exception: + pass try: detail = exc.response.json() except ValueError: diff --git a/services/output-handler-python/main.py b/services/output-handler-python/main.py index 49f4771..42ff805 100644 --- a/services/output-handler-python/main.py +++ b/services/output-handler-python/main.py @@ -26,6 +26,9 @@ task_status: Dict[str, str] = {} ingest_ws: Optional[WebSocket] = None # dialog-engine upstream connection _chunk_seq: Dict[str, int] = {} # per-session chunk counters +streaming_events: Dict[str, asyncio.Event] = {} # signal streaming TTS completion + +STREAMING_AUDIO_TIMEOUT = float(os.getenv("STREAMING_AUDIO_TIMEOUT", "120")) async def init_redis(): global redis_client @@ -70,11 +73,20 @@ async def handle_connection(self, websocket: WebSocket, task_id: str): await websocket.accept() active_connections[task_id] = websocket task_status[task_id] = "connected" + stream_event = asyncio.Event() + 
streaming_events[task_id] = stream_event logger.info(f"Output connection established for task_id: {task_id}") try: # 等待处理结果 - await self._wait_for_result(websocket, task_id) + expects_stream = await self._wait_for_result(websocket, task_id) + if expects_stream: + logger.info("Waiting for streaming audio to finish for task %s", task_id) + try: + await asyncio.wait_for(stream_event.wait(), timeout=STREAMING_AUDIO_TIMEOUT) + logger.info("Streaming audio completed for task %s", task_id) + except asyncio.TimeoutError: + logger.warning("Timed out waiting for streaming audio completion for task %s", task_id) except WebSocketDisconnect: logger.info(f"Output connection disconnected, task_id: {task_id}") except Exception as e: @@ -91,14 +103,18 @@ async def handle_connection(self, websocket: WebSocket, task_id: str): del active_connections[task_id] if task_id in task_status: del task_status[task_id] + event = streaming_events.pop(task_id, None) + if event is not None: + event.set() - async def _wait_for_result(self, websocket: WebSocket, task_id: str): + async def _wait_for_result(self, websocket: WebSocket, task_id: str) -> bool: + streaming_pending = False if not redis_client: await websocket.send_text(json.dumps({ "status": "error", "error": "Redis connection not available" })) - return + return streaming_pending pubsub = None try: @@ -134,7 +150,12 @@ async def _wait_for_result(self, websocket: WebSocket, task_id: str): try: logger.info(f"Received message on {channel_name}: {message['data'][:100]}...") response_data = json.loads(message["data"]) - await self._send_response(websocket, task_id, response_data) + status = str(response_data.get("status") or "").lower() + if status == "streaming": + await self._send_stream_event(websocket, task_id, response_data) + continue + expects_stream = await self._send_response(websocket, task_id, response_data) + streaming_pending = streaming_pending or expects_stream task_status[task_id] = "completed" logger.info(f"Successfully 
processed response for task {task_id}") break @@ -168,8 +189,29 @@ async def _wait_for_result(self, websocket: WebSocket, task_id: str): logger.debug(f"Cleaned up pubsub for task {task_id}") except Exception as e: logger.error(f"Error cleaning up pubsub: {e}") + return streaming_pending - async def _send_response(self, websocket: WebSocket, task_id: str, response_data: dict): + async def _send_stream_event(self, websocket: WebSocket, task_id: str, payload: dict) -> None: + event = payload.get("event") + message = { + "status": "streaming", + "task_id": task_id, + "event": event, + "text": payload.get("text"), + "confidence": payload.get("confidence"), + "content": payload.get("content"), + "is_final": payload.get("is_final"), + "eos": payload.get("eos"), + } + if event and event.lower() == "asr-final": + task_status[task_id] = "asr-final" + try: + await websocket.send_text(json.dumps(message)) + except Exception as exc: + logger.error(f"Failed to send streaming event for task {task_id}: {exc}") + + async def _send_response(self, websocket: WebSocket, task_id: str, response_data: dict) -> bool: + streaming_audio = False try: status = str(response_data.get("status") or "success").lower() if status != "success": @@ -183,16 +225,18 @@ async def _send_response(self, websocket: WebSocket, task_id: str, response_data error_payload["meta"] = response_data["meta"] await websocket.send_text(json.dumps(error_payload)) logger.info(f"Sent error response for task {task_id}") - return + return streaming_audio content = response_data.get("text") if not isinstance(content, str): content = response_data.get("reply", "") + audio_payload = response_data.get("audio") if isinstance(response_data.get("audio"), dict) else {} + streaming_audio = bool(audio_payload.get("stream")) text_response = { "status": "success", "task_id": task_id, "content": content or "", - "audio_present": bool(response_data.get("audio_file") or response_data.get("audio")), + "audio_present": 
bool(response_data.get("audio_file") or audio_payload), "transcript": response_data.get("transcript"), "stats": response_data.get("stats"), "source": response_data.get("source", "dialog-engine"), @@ -206,6 +250,7 @@ async def _send_response(self, websocket: WebSocket, task_id: str, response_data audio_file = response_data.get("audio_file") if audio_file: await self._send_audio_chunks(websocket, task_id, audio_file) + streaming_audio = False # handled synchronously via file chunks except Exception as e: logger.error(f"Error sending response: {e}") @@ -213,6 +258,8 @@ async def _send_response(self, websocket: WebSocket, task_id: str, response_data "status": "error", "error": str(e) })) + streaming_audio = False + return streaming_audio async def _send_audio_chunks(self, websocket: WebSocket, task_id: str, audio_file: str): try: @@ -289,13 +336,20 @@ async def relay_speech_chunk(self, session_id: str, pcm_b64: str, seq: Optional[ logger.error(f"Failed to relay chunk to client {session_id}: {e}") async def relay_control(self, session_id: str, action: str): + upper_action = action.upper() ws = active_connections.get(session_id) - if not ws: - return - try: - await ws.send_text(json.dumps({"type": "control", "action": action, "task_id": session_id})) - except Exception: - pass + if ws: + try: + payload = {"type": "control", "action": action, "task_id": session_id} + await ws.send_text(json.dumps(payload)) + if upper_action == "END": + await ws.send_text(json.dumps({"type": "audio_complete", "task_id": session_id})) + except Exception: + pass + if upper_action in {"END", "STOP_ACK"}: + event = streaming_events.get(session_id) + if event: + event.set() # 初始化处理器 output_handler = OutputHandler()