Skip to content

Commit c2e5610

Browse files
feat: wire MCP server and HTTP API through ModalityBus
Both server.py (MCP) and http_api.py (REST) now route through the ModalityBus abstraction layer:

server.py:
- Initialize _bus singleton with VoiceModule at module level
- diagnostics() includes bus.health() and bus.hud()
- speak() validates voice through bus-registered module
- vad_check() uses bus perception path (gate.check)
- Streaming playback via AdaptivePlayer preserved unchanged

http_api.py:
- Imports shared _bus from server.py (with safe fallback)
- /v1/synthesize and /v1/audio/speech validate voice via bus
- /v1/vad routes through bus gate + perceive
- /diagnostics includes bus state

All 10 existing tests pass. No MCP tool signature changes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent aabb079 commit c2e5610

2 files changed

Lines changed: 353 additions & 31 deletions

File tree

http_api.py

Lines changed: 113 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import struct
1616
import time
1717
import uuid
18+
import wave
1819
from collections import OrderedDict
1920
from threading import Lock
2021

@@ -24,14 +25,72 @@
2425

2526
from bus import ModalityBus
2627
from engine import MODELS, generate_audio, get_loaded_engines, resolve_model
27-
from modality import EncodedOutput
28+
from modality import EncodedOutput, ModalityType
2829
from modules.text import TextModule
2930
from modules.voice import VoiceModule
3031
from vad import detect_speech_file, is_hallucination
3132
from vad import is_model_loaded as vad_loaded
3233

3334
app = FastAPI(title="Mod³", description="Local multi-model TTS on Apple Silicon")
3435

36+
# Reuse the MCP server's bus so both front-ends observe the same registered
# modules and state. Importing ``server`` runs that module's top-level code,
# which can fail for reasons other than ImportError, so the catch stays broad
# on purpose. The failure is logged rather than swallowed: without a trace, a
# silently separate bus is very hard to diagnose.
try:
    from server import _bus as _shared_bus
except Exception:
    import logging

    logging.getLogger(__name__).warning(
        "Could not import shared ModalityBus from server; using a standalone bus.",
        exc_info=True,
    )
    _shared_bus = ModalityBus()

_bus = _shared_bus
# Held by the /v1/vad handler while it runs the voice gate check.
_bus_vad_lock = Lock()
43+
44+
45+
def _ensure_bus_modules() -> None:
    """Register the default text and voice modules on the bus when absent.

    The bus's ``_modules`` mapping is read once (a missing attribute is treated
    as an empty mapping); for each modality type not present in it, a freshly
    constructed module is registered.
    """
    registered = getattr(_bus, "_modules", {})
    defaults = (
        (ModalityType.TEXT, TextModule),
        (ModalityType.VOICE, VoiceModule),
    )
    for modality_type, module_cls in defaults:
        if modality_type in registered:
            continue
        # Instantiate lazily so no module is constructed unless it is needed.
        _bus.register(module_cls())
51+
52+
53+
def _get_voice_module() -> VoiceModule | None:
    """Return the bus's voice module, or ``None`` if there is no usable one.

    ``None`` is returned both when nothing is registered under
    ``ModalityType.VOICE`` and when the registered object is not a
    ``VoiceModule`` instance.
    """
    registry = getattr(_bus, "_modules", {})
    candidate = registry.get(ModalityType.VOICE)
    if isinstance(candidate, VoiceModule):
        return candidate
    return None
56+
57+
58+
def _resolve_voice_via_bus(voice: str) -> str:
    """Validate *voice* against the bus's voice module and the model catalog.

    Returns:
        The unchanged ``voice`` string when some entry in ``MODELS`` lists it
        under its ``"voices"`` key.

    Raises:
        ValueError: If no ``VoiceModule`` with an encoder is registered on the
            bus, or if no model configuration lists ``voice``.
    """
    registered = _get_voice_module()
    if registered is None or registered.encoder is None:
        raise ValueError("Voice module is not registered on the ModalityBus.")

    is_known = any(voice in model_cfg["voices"] for model_cfg in MODELS.values())
    if not is_known:
        raise ValueError(f"Unknown voice '{voice}'. Use /v1/voices to see options.")
    return voice
68+
69+
70+
def _read_wav_as_mono_float32(raw_wav: bytes) -> tuple[bytes, int]:
    """Decode integer-PCM WAV bytes into mono float32 samples.

    Supports 8-bit unsigned and 16/24/32-bit signed PCM. Samples are scaled to
    the range [-1.0, 1.0) and multi-channel audio is averaged down to mono.

    Args:
        raw_wav: Complete contents of a WAV file.

    Returns:
        ``(samples, sample_rate)`` where ``samples`` is the raw byte buffer of
        a float32 array and ``sample_rate`` is the frame rate from the header.

    Raises:
        ValueError: If the sample width is not 1, 2, 3 or 4 bytes.
        wave.Error: If ``raw_wav`` cannot be parsed by the ``wave`` module.
    """
    import sys

    import numpy as np

    with wave.open(io.BytesIO(raw_wav), "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        n_channels = wav_file.getnchannels()
        sample_width = wav_file.getsampwidth()
        frames = wav_file.readframes(wav_file.getnframes())

    if sample_width not in (1, 2, 3, 4):
        raise ValueError(f"Unsupported WAV sample width: {sample_width} bytes")

    # A truncated upload can end mid-frame; drop the partial frame so the
    # fixed-width decoding and the channel reshape below cannot fail on it.
    frame_size = sample_width * n_channels
    frames = frames[: len(frames) - len(frames) % frame_size]

    if sample_width == 1:
        # 8-bit WAV PCM is unsigned with its zero level at 128.
        audio = (np.frombuffer(frames, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    elif sample_width == 2:
        audio = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
    elif sample_width == 3:
        # NumPy has no 24-bit dtype: assemble each sample from its three bytes
        # and sign-extend manually.
        octets = np.frombuffer(frames, dtype=np.uint8).reshape(-1, 3).astype(np.int32)
        if sys.byteorder == "big":
            # wave.readframes() hands back native-endian samples.
            octets = octets[:, ::-1]
        packed = octets[:, 0] | (octets[:, 1] << 8) | (octets[:, 2] << 16)
        signed = np.where(packed >= 0x800000, packed - 0x1000000, packed)
        audio = signed.astype(np.float32) / 8388608.0
    else:
        # The wave module only reads integer PCM, so 4 bytes means int32.
        audio = np.frombuffer(frames, dtype=np.int32).astype(np.float32) / 2147483648.0

    if n_channels > 1:
        audio = audio.reshape(-1, n_channels).mean(axis=1)

    return audio.astype(np.float32).tobytes(), sample_rate
90+
91+
92+
# Import-time side effect: make sure the bus (shared or fallback) has its text
# and voice modules registered as soon as this module is loaded.
_ensure_bus_modules()
93+
3594
# ---------------------------------------------------------------------------
3695
# Job ledger — full lifecycle tracking for every generation
3796
# ---------------------------------------------------------------------------
@@ -138,7 +197,7 @@ def synthesize(req: SynthesizeRequest):
138197
)
139198

140199
try:
141-
resolve_model(req.voice)
200+
req.voice = _resolve_voice_via_bus(req.voice)
142201
except ValueError as e:
143202
_update_job(job_id, {"status": "error", "error": str(e)})
144203
return JSONResponse(status_code=400, content={"error": str(e), "job_id": job_id})
@@ -234,7 +293,7 @@ def audio_speech(req: SpeechRequest):
234293

235294
voice = req.voice
236295
try:
237-
resolve_model(voice)
296+
voice = _resolve_voice_via_bus(voice)
238297
except ValueError:
239298
voice = "af_heart"
240299

@@ -316,12 +375,37 @@ async def vad_check(file: UploadFile):
316375
}
317376
)
318377

319-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
320-
content = await file.read()
321-
tmp.write(content)
322-
tmp.flush()
323-
t_load = time.perf_counter()
324-
result = detect_speech_file(tmp.name)
378+
content = await file.read()
379+
t_load = time.perf_counter()
380+
381+
voice_module = _get_voice_module()
382+
if voice_module is not None and voice_module.gate is not None:
383+
raw_audio, sample_rate = _read_wav_as_mono_float32(content)
384+
with _bus_vad_lock:
385+
gate_result = voice_module.gate.check(raw_audio, sample_rate=sample_rate, sample_width=4)
386+
_bus.perceive(
387+
raw_audio,
388+
modality=ModalityType.VOICE,
389+
channel="http:v1/vad",
390+
sample_rate=sample_rate,
391+
sample_width=4,
392+
transcript="speech detected",
393+
)
394+
395+
class _Result:
396+
has_speech = gate_result.passed
397+
confidence = gate_result.confidence
398+
speech_ratio = gate_result.metadata.get("speech_ratio", 0.0)
399+
num_segments = gate_result.metadata.get("num_segments", 0)
400+
total_speech_sec = gate_result.metadata.get("total_speech_sec", 0.0)
401+
total_audio_sec = gate_result.metadata.get("total_audio_sec", 0.0)
402+
403+
result = _Result()
404+
else:
405+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as tmp:
406+
tmp.write(content)
407+
tmp.flush()
408+
result = detect_speech_file(tmp.name)
325409

326410
t_end = time.perf_counter()
327411
processing_time = t_end - t_start
@@ -443,14 +527,30 @@ def health():
443527
}
444528

445529

530+
@app.get("/diagnostics")
def diagnostics():
    """Report loaded engines, VAD status, job counts and ModalityBus state."""
    in_flight_states = ("generating", "processing")
    with _jobs_lock:
        job_count = len(_jobs)
        in_flight = sum(job.get("status") in in_flight_states for job in _jobs.values())

    job_summary = {"total": job_count, "active": in_flight}
    return {
        "engines_loaded": get_loaded_engines(),
        "vad_loaded": vad_loaded(),
        "jobs": job_summary,
        "bus": {"health": _bus.health(), "hud": _bus.hud()},
    }
548+
549+
446550
# ---------------------------------------------------------------------------
447551
# Modality Bus endpoints
448552
# ---------------------------------------------------------------------------
449553

450-
_bus = ModalityBus()
451-
_bus.register(TextModule())
452-
_bus.register(VoiceModule())
453-
454554

455555
@app.get("/v1/bus/hud")
456556
def bus_hud():

0 commit comments

Comments
 (0)