Conversation
Pull request overview
This PR introduces a Gemini Live–based audio streaming session that can fetch relevant “memory” from the vector database via tool calls, adds helper scripts to populate the DB with sample content, and updates tests/dependencies to support the new integration.
Changes:
- Replace the websocket audio pipeline in main.py to use GeminiLiveSession instead of the prior ASR integration.
- Add Gemini tool functions (save_information, fetch_information) plus unit tests, and add a basic Gemini Live session test.
- Add DB population scripts (with and without vectors) and update ASR stability handling + dependency pins.
Reviewed changes
Copilot reviewed 10 out of 12 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| src/main.py | Switch websocket handling from StreamingASR to GeminiLiveSession lifecycle. |
| src/gemini_live.py | New Gemini Live session wrapper w/ audio queueing and tool-call handling. |
| src/gemini_tools.py | New DB-backed tool functions for saving/fetching vector “memory”. |
| src/asr.py | Update streaming transcript handling to consider result stability and emit combined partials. |
| src/tests/asr_test.py | Update ASR tests for new stability field + changed partial transcript behavior. |
| src/tests/gemini_tools_test.py | New unit tests for Gemini tool functions. |
| src/tests/gemini_live_test.py | New unit tests for basic GeminiLiveSession behavior/constants. |
| scripts/populate_db.py | New script to populate categories/conversations/vectors for demo/testing. |
| scripts/populate_db_simple.py | New script to populate categories/conversations without vectors. |
| scripts/audio_streaming_helper.py | New helper to simulate streaming raw PCM audio over the websocket. |
| requirements.txt | Dependency bumps and additions (notably google-genai, pytest-asyncio). |
src/asr.py
Outdated
            payload = {
                "status": "partial", "text": self.final_buffer.strip() + " " + partial_text
            }
        else:
            continue
        # pylint: disable=unspecified-encoding
        with open('./scripts/response.txt', 'a') as file:
            file.write(str(self.final_buffer.strip() + ' ' + partial_text) + '\n')
The partial transcript payload concatenation always inserts a leading space (self.final_buffer.strip() + " " + partial_text). When final_buffer is empty this produces outputs like " hello". Build the string conditionally (only add the space if there is already committed text), so UI consumers don’t have to trim/handle formatting artifacts.
Suggested change:

    -     payload = {
    -         "status": "partial", "text": self.final_buffer.strip() + " " + partial_text
    -     }
    - else:
    -     continue
    - # pylint: disable=unspecified-encoding
    - with open('./scripts/response.txt', 'a') as file:
    -     file.write(str(self.final_buffer.strip() + ' ' + partial_text) + '\n')
    +     committed = self.final_buffer.strip()
    +     if committed:
    +         combined_text = committed + " " + partial_text
    +     else:
    +         combined_text = partial_text
    +     payload = {
    +         "status": "partial",
    +         "text": combined_text,
    +     }
    + else:
    +     continue
    + # pylint: disable=unspecified-encoding
    + with open('./scripts/response.txt', 'a') as file:
    +     file.write(str(payload["text"]) + '\n')
src/tests/asr_test.py
Outdated
      asr = StreamingASR(ws, testing=True, client=client)
      asr._worker()  # pylint: disable=protected-access
    - assert ws.sent[0] == {"type": "transcript", "data": {"status": "partial", "text": "hello"}}
    + assert ws.sent[0] == {"type": "transcript", "data": {"status": "partial", "text": " hello"}}
This test is asserting a leading space in the partial transcript (" hello"), which is an artifact of the current concatenation logic rather than desired behavior. It would be better to fix the production code to avoid the leading space and keep the test expectation trimmed (e.g., "hello").
Suggested change:

    - assert ws.sent[0] == {"type": "transcript", "data": {"status": "partial", "text": " hello"}}
    + assert ws.sent[0] == {"type": "transcript", "data": {"status": "partial", "text": "hello"}}
src/asr.py
Outdated
    # pylint: disable=unspecified-encoding
    with open('./scripts/response.txt', 'a') as file:
        file.write(str(self.final_buffer.strip() + ' ' + partial_text) + '\n')
StreamingASR._worker() now writes every partial/final update to ./scripts/response.txt. This introduces unexpected filesystem I/O in the main ASR path, can fail in containerized/prod environments (missing relative path / permissions), and can become a bottleneck. If this is for debugging, gate it behind testing/an explicit debug flag or remove it from the core module.
Suggested change:

    - # pylint: disable=unspecified-encoding
    - with open('./scripts/response.txt', 'a') as file:
    -     file.write(str(self.final_buffer.strip() + ' ' + partial_text) + '\n')
    + # Only write to a local debug file when running in testing mode.
    + if self.testing:
    +     # pylint: disable=unspecified-encoding
    +     try:
    +         with open('./scripts/response.txt', 'a') as file:
    +             file.write(str(self.final_buffer.strip() + ' ' + partial_text) + '\n')
    +     except OSError:
    +         # Ignore file I/O issues in testing/debug mode.
    +         pass
    if fc.name == "fetch_information":
        query = fc.args.get("query", "")
        print(f"fetching information for query: {query!r}")
        tool_result = fetch_information(query)
fetch_information() / vector search is synchronous and can involve network calls (Vertex embeddings + DB query). Calling it directly inside the async receive loop will block the event loop and can delay audio send/receive and websocket delivery. Offload the tool execution to a thread (asyncio.to_thread) or make the tool function async so the event loop stays responsive.
Suggested change:

    - tool_result = fetch_information(query)
    + tool_result = await asyncio.to_thread(fetch_information, query)
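To illustrate the point above in isolation: a synchronous call inside a coroutine stalls every other task on the loop, while `asyncio.to_thread` keeps it responsive. In the sketch below, `slow_fetch` is a hypothetical stand-in for the embeddings + DB lookup, not the project's actual `fetch_information`:

```python
import asyncio
import time


def slow_fetch(query: str) -> dict:
    """Hypothetical stand-in for a synchronous DB/embedding lookup."""
    time.sleep(0.2)  # simulate blocking network / DB latency
    return {"status": "ok", "information": f"results for {query!r}"}


async def main() -> int:
    ticks = 0

    async def heartbeat():
        # Stands in for audio send/receive work that must stay responsive.
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    hb = asyncio.create_task(heartbeat())
    # Offloading to a worker thread lets the heartbeat keep running;
    # calling slow_fetch(...) directly here would freeze it for 0.2 s.
    result = await asyncio.to_thread(slow_fetch, "user preferences")
    hb.cancel()
    assert result["status"] == "ok"
    return ticks


if __name__ == "__main__":
    print(f"heartbeat ticks during fetch: {asyncio.run(main())}")
```

With the direct call, the tick count during the fetch would be 0; offloaded, the loop keeps servicing the heartbeat.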
    async def stop(self):
        try:
            self._audio_queue.put_nowait(None)
        except asyncio.QueueFull:
            pass
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except (asyncio.CancelledError, Exception):  # pylint: disable=broad-except
                pass
        print(f"session total tokens: {self.tokens_used}")

    async def _run(self):
        # wait for first audio chunk before opening the connection
        first_chunk = await self._audio_queue.get()
        if first_chunk is None:
            return  # stopped before any audio arrived
        client = genai.Client()
        try:
            async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
                print("Gemini Live connected")
                await session.send_realtime_input(
                    audio={"data": first_chunk, "mime_type": "audio/pcm;rate=16000"}
                )
                send_task = asyncio.create_task(self._send(session))
                recv_task = asyncio.create_task(self._receive(session))
                await send_task
                recv_task.cancel()
                try:
                    await recv_task
                except asyncio.CancelledError:
                    pass
        except Exception as e:  # pylint: disable=broad-except
            print(f"Gemini Live error: {e}")
        finally:
            await client.aio.aclose()
stop() cancels only the top-level _run task. If cancellation happens after _send/_receive tasks are created, those tasks may be left running (and _send may also block forever if the stop sentinel couldn’t be enqueued due to QueueFull). Track the child tasks and cancel them in a finally, and ensure stop reliably unblocks _send (e.g., by draining one item and enqueuing None, or by cancelling _send).
src/main.py
Outdated
      geminiLive = None

      @app.websocket("/ws/")
      async def audio_ws(ws: WebSocket):
          global geminiLive
          await ws.accept()
          await ws.send_json({"type": "control", "cmd": "ready"})
    -     while True:
    -         msg = await ws.receive()
    -         if msg["type"] == "websocket.disconnect":
    -             break
    -         if msg["type"] == "websocket.receive":
    -             if "bytes" in msg:  # audio arrives as binary
    -                 if not asr:
    -                     await ws.send_json({"type": "error", "message": "ASR not started"})
    -                     print("Received audio chunk but ASR not started")
    -                     continue
    -                 asr.push_audio(msg["bytes"])
    -             elif "text" in msg:  # everything other than audio arrives as text
    -                 await handle_text(msg["text"], ws)
    +     try:
    +         while True:
    +             msg = await ws.receive()
    +             if msg["type"] == "websocket.disconnect":
    +                 break
    +             if msg["type"] == "websocket.receive":
    +                 if "bytes" in msg:  # audio arrives as binary
    +                     if not geminiLive:
    +                         await ws.send_json({"type": "error", "message": "ASR not started"})
    +                         print("Received audio chunk but ASR not started")
    +                         continue
    +                     geminiLive.push_audio(msg["bytes"])
    +                 elif "text" in msg:  # everything other than audio arrives as text
    +                     await handle_text(msg["text"], ws)
    +     finally:
    +         if geminiLive:
    +             await geminiLive.stop()
    +         geminiLive = None
geminiLive is stored as a module-level global and reused across websocket connections. This will cause cross-talk between clients (one client can stop another client’s session, audio chunks can be routed to the wrong session) and is not safe under concurrent connections. Keep the session per-connection instead (e.g., create it as a local variable inside audio_ws, or attach it to ws.state) and remove the global.
    # pylint: disable=invalid-name global-statement
    app = FastAPI()
    asr = None  # pylint: disable=invalid-name
    geminiLive = None
The identifier geminiLive uses camelCase and suppresses invalid-name, while the rest of the codebase (e.g., previous asr global) uses snake_case. Renaming to gemini_live/gemini_live_session (and dropping the broad pylint disable) would keep naming consistent and avoid hiding other invalid-name issues.
    CONFIG = genai.types.LiveConnectConfig(
        response_modalities=["AUDIO"],
LiveConnectConfig is configured with response_modalities=["AUDIO"], but the system instruction explicitly says “Do not generate audio.” This mismatch can lead to unintended audio output and extra token/bandwidth usage. Align the config with the intended behavior (e.g., use text-only responses, or disable response modalities if the API supports it).
Suggested change:

    - response_modalities=["AUDIO"],
    + response_modalities=["TEXT"],
    print(f"fetch result: {tool_result}")
    await self.ws.send_json(
        {"type": "ai", "data": tool_result["information"]}
    )
Tool calls are executed locally, but the tool result is never sent back to the Gemini Live session. In most tool-calling flows, the model will wait for a tool response event/message to continue the turn; sending only to the websocket client won’t satisfy the tool call. Add the appropriate SDK call to return the function result to session (in addition to any UI message you want to emit).
Suggested change:

    - print(f"fetch result: {tool_result}")
    - await self.ws.send_json(
    -     {"type": "ai", "data": tool_result["information"]}
    - )
    + print(f"fetch result: {tool_result}")
    + # Send result to the websocket client UI.
    + await self.ws.send_json(
    +     {"type": "ai", "data": tool_result["information"]}
    + )
    + # Also send the tool result back to the Gemini Live session
    + # so the model can continue the turn.
    + try:
    +     await session.send_tool_response(
    +         name=fc.name,
    +         call_id=getattr(fc, "id", None),
    +         result=tool_result,
    +     )
    + except AttributeError:
    +     # Fallback if the SDK uses a different attribute name
    +     # or structure for identifying the tool call.
    +     await session.send_tool_response(
    +         name=fc.name,
    +         call_id=None,
    +         result=tool_result,
    +     )
    await self.ws.send_json(
        {"type": "ai", "data": tool_result["information"]}
    )
tool_result["information"] is accessed unconditionally, but fetch_information() can return an error dict without an information key (e.g., empty query or exceptions). This will raise KeyError and break the receive loop. Handle the error case explicitly (check status, or use .get() with a fallback) before sending to the websocket.
Suggested change:

    - await self.ws.send_json(
    -     {"type": "ai", "data": tool_result["information"]}
    - )
    + # Safely handle possible error dicts or missing "information" key
    + information = None
    + if isinstance(tool_result, dict):
    +     information = tool_result.get("information")
    + if information is None:
    +     # Fallback: use explicit error message if available, otherwise a generic one
    +     fallback_msg = ""
    +     if isinstance(tool_result, dict):
    +         fallback_msg = tool_result.get("error") or ""
    +     if not fallback_msg:
    +         fallback_msg = "No information available."
    +     await self.ws.send_json(
    +         {"type": "ai", "data": fallback_msg}
    +     )
    + else:
    +     await self.ws.send_json(
    +         {"type": "ai", "data": information}
    +     )
Audio-streaming AI integration that fetches relevant information from the database via tool calls.