diff --git a/src/ai/agents/ai_caller.py b/src/ai/agents/ai_caller.py index 816a8f9..6042eca 100644 --- a/src/ai/agents/ai_caller.py +++ b/src/ai/agents/ai_caller.py @@ -229,7 +229,7 @@ def call_with_provider( elif provider_lower == PROVIDER_ANTHROPIC: result = self._call_anthropic(model, system_message, prompt, temperature) elif provider_lower == PROVIDER_OLLAMA: - result = self._call_ollama(system_message, prompt, temperature) + result = self._call_ollama(system_message, prompt, temperature, model) elif provider_lower == PROVIDER_GEMINI: result = self._call_gemini(model, system_message, prompt, temperature) elif provider_lower == PROVIDER_GROQ: diff --git a/src/ai/chat_context_mixin.py b/src/ai/chat_context_mixin.py index c208aa9..cfcf09a 100644 --- a/src/ai/chat_context_mixin.py +++ b/src/ai/chat_context_mixin.py @@ -113,9 +113,11 @@ def _construct_prompt(self, user_message: str, context_data: 'ChatContextData') ]) # Add conversation history (last few exchanges) - if self.conversation_history: + with self._history_lock: + history_snapshot = list(self.conversation_history[-6:]) + if history_snapshot: prompt_parts.append("Recent Conversation:") - for item in self.conversation_history[-6:]: # Last 3 exchanges + for item in history_snapshot: # Last 3 exchanges role = item["role"].title() message = item["message"][:200] + "..." if len(item["message"]) > 200 else item["message"] prompt_parts.append(f"{role}: {message}") @@ -132,31 +134,35 @@ def _construct_prompt(self, user_message: str, context_data: 'ChatContextData') def _add_to_history(self, role: str, message: str): """Add a message to conversation history.""" - self.conversation_history.append({ - "role": role, - "message": message, - "timestamp": datetime.now().isoformat() - }) + with self._history_lock: + self.conversation_history.append({ + "role": role, + "message": message, + "timestamp": datetime.now().isoformat() + }) - # Keep only recent history - if len(self.conversation_history) > self.max_history_items: - self.conversation_history = self.conversation_history[-self.max_history_items:] + # Keep only recent history + if len(self.conversation_history) > self.max_history_items: + self.conversation_history = self.conversation_history[-self.max_history_items:] def clear_history(self): """Clear conversation history.""" - self.conversation_history = [] + with self._history_lock: + self.conversation_history = [] logger.info("Chat conversation history cleared") def get_history(self) -> list: """Get conversation history.""" - return self.conversation_history.copy() + with self._history_lock: + return self.conversation_history.copy() def get_context_from_history(self, max_entries: int = 5) -> str: """Get context from recent conversation history.""" - if not self.conversation_history: - return "" + with self._history_lock: + if not self.conversation_history: + return "" + recent_history = list(self.conversation_history[-max_entries:]) - recent_history = self.conversation_history[-max_entries:] context_parts = [] for entry in recent_history: diff --git a/src/ai/chat_processor.py b/src/ai/chat_processor.py index 2cdec37..e0e5570 100644 --- a/src/ai/chat_processor.py +++ b/src/ai/chat_processor.py @@ -94,6 +94,7 @@ def __init__(self, app): self.app = app self.is_processing = False self.conversation_history = [] + self._history_lock = threading.Lock() # Typing indicator state self._typing_indicator_mark = None diff --git a/src/ai/mcp/mcp_manager.py b/src/ai/mcp/mcp_manager.py index b19b71f..04a5bec 100644 --- a/src/ai/mcp/mcp_manager.py +++ b/src/ai/mcp/mcp_manager.py @@ -78,13 +78,14 @@ def __init__(self, process: subprocess.Popen, server: MCPServer = None): self.process = process self.server = server # Reference to store errors self.request_id = 0 - self.response_queue = queue.Queue() + self._pending_requests: Dict[int, queue.Queue] = {} + self._pending_lock = threading.Lock() self.reader_thread = threading.Thread(target=self._read_responses, daemon=True) self.reader_thread.start() # Start stderr reader to capture errors self.stderr_thread = threading.Thread(target=self._read_stderr, daemon=True) self.stderr_thread.start() - + def _read_responses(self): """Read responses from the MCP server stdout""" try: @@ -100,14 +101,23 @@ def _read_responses(self): try: response = json.loads(line) logger.debug(f"MCP response: {response}") - self.response_queue.put(response) + resp_id = response.get("id") + with self._pending_lock: + req_queue = self._pending_requests.get(resp_id) + if req_queue is not None: + req_queue.put(response) + else: + logger.warning(f"Received response for unknown request id {resp_id}, discarding") except json.JSONDecodeError: logger.warning(f"Invalid JSON from MCP server: {line}") except Exception as e: logger.error(f"Error reading MCP responses: {e}") finally: - # Notify any waiters that the reader has exited so they don't hang forever - self.response_queue.put({"error": {"code": -1, "message": "MCP reader thread exited"}}) + # Notify all pending waiters that the reader has exited so they don't hang forever + exit_msg = {"error": {"code": -1, "message": "MCP reader thread exited"}} + with self._pending_lock: + for req_queue in self._pending_requests.values(): + req_queue.put(exit_msg) def _read_stderr(self): """Read stderr output from the MCP server""" @@ -135,35 +145,41 @@ def send_request(self, method: str, params: Dict[str, Any] = None, timeout: floa # Check if process is still running if self.process.poll() is not None: raise Exception(f"MCP server process has terminated with code: {self.process.returncode}") - - self.request_id += 1 - request = { - "jsonrpc": "2.0", - "method": method, - "id": self.request_id - } - if params: - request["params"] = params - - # Send request - request_str = json.dumps(request) + "\n" - logger.debug(f"Sending MCP request: {request_str.strip()}") + + # Allocate a per-request queue and atomically assign request_id + req_queue = queue.Queue() + with self._pending_lock: + self.request_id += 1 + current_id = self.request_id + self._pending_requests[current_id] = req_queue + try: - self.process.stdin.write(request_str) - self.process.stdin.flush() - except Exception as e: - raise Exception(f"Failed to send request to MCP server: {e}") - - # Wait for response - start_time = time.time() - while time.time() - start_time < timeout: - # Check if process terminated while waiting - if self.process.poll() is not None: - raise Exception(f"MCP server process terminated while waiting for response (code: {self.process.returncode})") - + request = { + "jsonrpc": "2.0", + "method": method, + "id": current_id + } + if params: + request["params"] = params + + # Send request + request_str = json.dumps(request) + "\n" + logger.debug(f"Sending MCP request: {request_str.strip()}") try: - response = self.response_queue.get(timeout=0.1) - if response.get("id") == self.request_id: + self.process.stdin.write(request_str) + self.process.stdin.flush() + except Exception as e: + raise Exception(f"Failed to send request to MCP server: {e}") + + # Wait for response on our dedicated queue + start_time = time.time() + while time.time() - start_time < timeout: + # Check if process terminated while waiting + if self.process.poll() is not None: + raise Exception(f"MCP server process terminated while waiting for response (code: {self.process.returncode})") + + try: + response = req_queue.get(timeout=0.1) if "error" in response: error = response['error'] # Extract error details @@ -171,7 +187,7 @@ def send_request(self, method: str, params: Dict[str, Any] = None, timeout: floa code = error.get('code', 'Unknown') message = error.get('message', 'Unknown error') data = error.get('data', {}) - + # Check for rate limit error if code == 429 or 'rate' in str(message).lower(): retry_after = data.get('retry_after', 60) @@ -181,10 +197,14 @@ def send_request(self, method: str, params: Dict[str, Any] = None, timeout: floa else: raise Exception(f"MCP error: {error}") return response.get("result") - except queue.Empty: - continue - - raise Exception(f"Timeout waiting for MCP response to {method}") + except queue.Empty: + continue + + raise Exception(f"Timeout waiting for MCP response to {method}") + finally: + # Always clean up the per-request queue + with self._pending_lock: + self._pending_requests.pop(current_id, None) class MCPManager: @@ -330,27 +350,30 @@ def _discover_tools(self, protocol: MCPProtocol) -> List[Dict[str, Any]]: def execute_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Any: """Execute a tool on an MCP server""" + # Hold lock only for server/protocol lookup with self._lock: server = self.servers.get(server_name) if not server or not server.process: raise Exception(f"MCP server {server_name} is not running") - + if not server.protocol: raise Exception(f"MCP server {server_name} has no active protocol") - - try: - # Log the request for debugging - request_params = { - "name": tool_name, - "arguments": arguments - } - logger.debug(f"Sending tool call request: {request_params}") - - result = server.protocol.send_request("tools/call", request_params) - return result - except Exception as e: - logger.error(f"Failed to execute tool {tool_name} on {server_name}: {e}") - raise + + protocol = server.protocol + + # Perform the (potentially long) I/O outside the lock + try: + request_params = { + "name": tool_name, + "arguments": arguments + } + logger.debug(f"Sending tool call request: {request_params}") + + result = protocol.send_request("tools/call", request_params) + return result + except Exception as e: + logger.error(f"Failed to execute tool {tool_name} on {server_name}: {e}") + raise def get_all_tools(self) -> List[Tuple[str, Dict[str, Any]]]: """Get all available tools from all running servers""" @@ -475,6 +498,9 @@ def _monitor_loop(self) -> None: def _check_servers(self) -> None: """Check all servers and restart any that have crashed.""" + crashed_servers = [] + + # Hold lock only briefly to detect crashed servers and clean up stale refs with self.mcp_manager._lock: for name, server in list(self.mcp_manager.servers.items()): if not server.enabled: @@ -485,17 +511,25 @@ def _check_servers(self) -> None: continue if server.process.poll() is not None: - # Process has terminated + # Process has terminated — collect name and clean up stale refs exit_code = server.process.returncode logger.warning(f"MCP server '{name}' crashed with exit code {exit_code}") - self._handle_crash(name, server) + server.process = None + server.protocol = None + server.tools = None + crashed_servers.append(name) - def _handle_crash(self, name: str, server: MCPServer) -> None: + # Handle restarts outside the lock + for name in crashed_servers: + self._handle_crash(name) + + def _handle_crash(self, name: str) -> None: """Handle a crashed server with restart logic. + Does NOT assume any lock is held. start_server acquires its own lock. + Args: name: Server name - server: Server instance """ attempts = self.restart_attempts.get(name, 0) @@ -507,26 +541,16 @@ def _handle_crash(self, name: str, server: MCPServer) -> None: backoff = 2 ** attempts logger.info(f"Attempting to restart MCP server '{name}' in {backoff}s (attempt {attempts + 1}/{self.max_restarts})") - # Clean up old process - server.process = None - server.protocol = None - server.tools = None - - # Wait with backoff + # Wait with backoff (no lock held, so this is safe) time.sleep(backoff) if not self._running: return try: - # Release lock temporarily for restart - self.mcp_manager._lock.release() - try: - self.mcp_manager.start_server(name) - self.restart_attempts[name] = 0 # Reset on success - logger.info(f"Successfully restarted MCP server '{name}'") - finally: - self.mcp_manager._lock.acquire() + self.mcp_manager.start_server(name) + self.restart_attempts[name] = 0 # Reset on success + logger.info(f"Successfully restarted MCP server '{name}'") except Exception as e: self.restart_attempts[name] = attempts + 1 logger.error(f"Failed to restart MCP server '{name}': {e}") diff --git a/src/ai/mcp/mcp_tool_wrapper.py b/src/ai/mcp/mcp_tool_wrapper.py index e8271d8..043600c 100644 --- a/src/ai/mcp/mcp_tool_wrapper.py +++ b/src/ai/mcp/mcp_tool_wrapper.py @@ -233,36 +233,23 @@ def execute(self, **kwargs) -> ToolResult: # Check rate limits for this server if self.server_name in RATE_LIMITS: - with rate_limit_lock: - rate_config = RATE_LIMITS[self.server_name] - min_interval = rate_config.get("minimum_interval", 1.0 / rate_config["requests_per_second"]) - - current_time = time.time() - - # Check global rate limit for this server - global_last = rate_config.get("global_last_request", 0) - global_time_since_last = current_time - global_last - - if global_time_since_last < min_interval: - wait_time = min_interval - global_time_since_last - logger.info(f"Global rate limiting for {self.server_name}: waiting {wait_time:.2f}s") - time.sleep(wait_time) + tool_key = f"{self.server_name}:{self.original_name}" + while True: + with rate_limit_lock: + rate_config = RATE_LIMITS[self.server_name] + min_interval = rate_config.get("minimum_interval", 1.0 / rate_config["requests_per_second"]) current_time = time.time() - - # Also check per-tool rate limit - tool_key = f"{self.server_name}:{self.original_name}" - last_time = rate_config["last_request_time"].get(tool_key, 0) - time_since_last = current_time - last_time - - # If not enough time has passed for this specific tool, wait - if time_since_last < min_interval: - wait_time = min_interval - time_since_last - logger.info(f"Tool rate limiting: waiting {wait_time:.2f}s before calling {self.original_name}") - time.sleep(wait_time) - - # Update last request times - rate_config["global_last_request"] = time.time() - rate_config["last_request_time"][tool_key] = time.time() + global_last = rate_config.get("global_last_request", 0) + global_wait = min_interval - (current_time - global_last) + last_time = rate_config["last_request_time"].get(tool_key, 0) + tool_wait = min_interval - (current_time - last_time) + wait_time = max(global_wait, tool_wait, 0.0) + if wait_time <= 0: + rate_config["global_last_request"] = current_time + rate_config["last_request_time"][tool_key] = current_time + break + logger.info(f"Rate limiting for {self.server_name}: waiting {wait_time:.2f}s") + time.sleep(wait_time) # Execute the tool via MCP with retry logic logger.info(f"Executing MCP tool {self.original_name} on server {self.server_name}") diff --git a/src/ai/providers/ollama_provider.py b/src/ai/providers/ollama_provider.py index 0ffdf2a..69152ee 100644 --- a/src/ai/providers/ollama_provider.py +++ b/src/ai/providers/ollama_provider.py @@ -54,13 +54,14 @@ def _get_first_available_model(session, base_url: str) -> str: return "" -def call_ollama(system_message: str, prompt: str, temperature: float) -> AIResult: +def call_ollama(system_message: str, prompt: str, temperature: float, model: str = "") -> AIResult: """Call local Ollama API to generate a response. Args: system_message: System message to guide the AI's response prompt: User prompt temperature: Temperature parameter (0.0 to 1.0) + model: Optional model name override. If provided, skips settings lookup. Returns: AIResult: Type-safe result wrapper. Use result.text for content, @@ -79,11 +80,12 @@ def call_ollama(system_message: str, prompt: str, temperature: float) -> AIResul ollama_url = get_ollama_url() base_url = ollama_url.rstrip("/") # Remove trailing slash if present - # Get model from settings based on the task - model_key = get_model_key_for_task(system_message, prompt) - model = settings_manager.get_nested(f"{model_key}.ollama_model", "") + # Use explicitly passed model, or fall back to settings-based resolution + if not model: + model_key = get_model_key_for_task(system_message, prompt) + model = settings_manager.get_nested(f"{model_key}.ollama_model", "") - # If no task-specific model configured, use global default or auto-detect + # If still no model, use global default or auto-detect if not model: model = settings_manager.get("ollama_default_model", "") if not model: diff --git a/src/ai/soap_processor.py b/src/ai/soap_processor.py index 5f801b0..b5d5f47 100644 --- a/src/ai/soap_processor.py +++ b/src/ai/soap_processor.py @@ -237,7 +237,6 @@ def task(): # Log success with transcript preview to verify speaker labels logger.info(f"Successfully transcribed audio, length: {len(transcript)} chars") - logger.info(f"Transcript preview: {repr(transcript[:200])}") # Update transcript tab with the raw transcript schedule_ui_update(self.app, lambda: [ @@ -319,6 +318,16 @@ def finalize_soap(): schedule_ui_update(self.app, lambda w=icd_warnings: self.app.document_generators._run_icd_validation_to_panel(w)) + # Medication QA: compare transcript medications against SOAP note + try: + from processing.soap_qa import compare_medications + soap_qa_warnings = compare_medications(transcript, soap_note) + except Exception as e: + logger.error(f"SOAP QA comparison failed: {e}") + soap_qa_warnings = [] + schedule_ui_update(self.app, lambda w=soap_qa_warnings: + self.app.document_generators._run_soap_qa_to_panel(w)) + # Display emotion data in panel if available if emotion_data: schedule_ui_update(self.app, lambda ed=emotion_data: diff --git a/src/audio/periodic_analysis.py b/src/audio/periodic_analysis.py index b936133..624cb20 100644 --- a/src/audio/periodic_analysis.py +++ b/src/audio/periodic_analysis.py @@ -162,7 +162,8 @@ def get_combined_history_text(self) -> str: header = f"Analysis #{entry['analysis_number']} (recording time: {formatted_time})" parts.append(f"{header}\n{entry['result_text']}") - return "\n\n" + "─" * 50 + "\n\n".join(parts) + separator = "\n\n" + "─" * 50 + "\n\n" + return separator.join(parts) def clear_history(self) -> None: """Clear the analysis history.""" diff --git a/src/core/app_initializer.py b/src/core/app_initializer.py index a392d3c..02fee70 100644 --- a/src/core/app_initializer.py +++ b/src/core/app_initializer.py @@ -927,6 +927,16 @@ def wrapper(): self.app.after(100, _safe_run("medication_analysis", self.app.document_generators._run_medication_to_panel, soap_note)) self.app.after(200, _safe_run("diagnostic_analysis", self.app.document_generators._run_diagnostic_to_panel, soap_note)) self.app.after(300, _safe_run("compliance_analysis", self.app.document_generators._run_compliance_to_panel, soap_note)) + + # Medication QA: compare transcript medications against SOAP note + if transcript: + try: + from processing.soap_qa import compare_medications + soap_qa_warnings = compare_medications(transcript, soap_note) + except Exception as e: + logger.error(f"SOAP QA comparison failed: {e}") + soap_qa_warnings = [] + self.app.after(400, _safe_run("soap_qa", self.app.document_generators._run_soap_qa_to_panel, soap_qa_warnings)) else: skipped_updates.append('soap') logger.info(f"Skipped SOAP update - user has modified content") diff --git a/src/core/controllers/processing_controller.py b/src/core/controllers/processing_controller.py index 02de60f..01b755d 100644 --- a/src/core/controllers/processing_controller.py +++ b/src/core/controllers/processing_controller.py @@ -542,8 +542,13 @@ def append_text(self, text: str) -> None: if (self.app.capitalize_next or not current or current[-1] in ".!?") and text: text = text[0].upper() + text[1:] self.app.capitalize_next = False - self.app.transcript_text.insert(tk.END, (" " if current and current[-1] != "\n" else "") + text) - self.app.text_chunks.append(f"chunk_{len(self.app.text_chunks)}") + insert_text = (" " if current and current[-1] != "\n" else "") + text + start = self.app.transcript_text.index(tk.END + "-1c") + self.app.transcript_text.insert(tk.END, insert_text) + end = self.app.transcript_text.index(tk.END + "-1c") + tag_name = f"chunk_{len(self.app.text_chunks)}" + self.app.transcript_text.tag_add(tag_name, start, end) + self.app.text_chunks.append(tag_name) self.app.transcript_text.see(tk.END) def scratch_that(self) -> None: diff --git a/src/core/handlers/periodic_analysis_handler.py b/src/core/handlers/periodic_analysis_handler.py index 920cc94..019e4a9 100644 --- a/src/core/handlers/periodic_analysis_handler.py +++ b/src/core/handlers/periodic_analysis_handler.py @@ -55,6 +55,7 @@ def __init__(self, app: 'MedicalDictationApp'): def _get_database(self) -> Database: """Get or create database connection.""" if self._db is None: + logger.warning(f"{self.__class__.__name__}: creating local Database instance — prefer passing db= to constructor") self._db = Database() return self._db diff --git a/src/managers/autosave_manager.py b/src/managers/autosave_manager.py index fc7a9ee..79f96e0 100644 --- a/src/managers/autosave_manager.py +++ b/src/managers/autosave_manager.py @@ -376,18 +376,45 @@ class AutoSaveDataProvider: @staticmethod def create_text_widget_provider(widget, name: str) -> Callable[[], Dict[str, Any]]: - """Create a provider for a text widget.""" + """Create a provider for a text widget. Thread-safe.""" + def _read_widget() -> Dict[str, Any]: + return { + "name": name, + "content": widget.get("1.0", "end-1c"), + "cursor_position": widget.index("insert"), + } + + _empty = {"name": name, "content": "", "cursor_position": "1.0"} + def provider(): + if threading.current_thread() is threading.main_thread(): + try: + return _read_widget() + except (tk.TclError, AttributeError, RuntimeError): + return dict(_empty) + + result_holder: Dict[str, Any] = {} + event = threading.Event() + + def _on_main(): + try: + result_holder["data"] = _read_widget() + except (tk.TclError, AttributeError, RuntimeError): + result_holder["data"] = dict(_empty) + finally: + event.set() + try: - return { - "name": name, - "content": widget.get("1.0", "end-1c"), - "cursor_position": widget.index("insert") - } - except (tk.TclError, AttributeError, RuntimeError): - # Widget may be destroyed or not initialized - return {"name": name, "content": "", "cursor_position": "1.0"} - + widget.after(0, _on_main) + except (tk.TclError, RuntimeError): + return dict(_empty) + + if not event.wait(timeout=5.0): + logger.warning(f"Timeout reading widget '{name}' on main thread") + return dict(_empty) + + return result_holder["data"] + return provider @staticmethod diff --git a/src/managers/notification_manager.py b/src/managers/notification_manager.py index 74157ae..768b566 100644 --- a/src/managers/notification_manager.py +++ b/src/managers/notification_manager.py @@ -33,7 +33,9 @@ def __init__(self, app: Any) -> None: self.notification_queue: Queue[Dict[str, Any]] = Queue() self.active_toasts: List[tk.Toplevel] = [] self.notification_history: List[Dict[str, Any]] = [] - + self._history_lock = threading.Lock() + self._running = True + # Start notification processor thread self.processor_thread = threading.Thread(target=self._process_notifications, daemon=True) self.processor_thread.start() @@ -60,9 +62,12 @@ def show_completion(self, patient_name: str, recording_id: int, } self.notification_queue.put(notification) - self.notification_history.append(notification) - - logger.info(f"Queued completion notification for patient: {patient_name}") + with self._history_lock: + self.notification_history.append(notification) + if len(self.notification_history) > 200: + self.notification_history = self.notification_history[-200:] + + logger.info("Queued completion notification", patient_name=patient_name) def show_error(self, patient_name: str, error_message: str, recording_id: int, task_id: str): @@ -84,9 +89,12 @@ def show_error(self, patient_name: str, error_message: str, } self.notification_queue.put(notification) - self.notification_history.append(notification) - - logger.error(f"Queued error notification for patient: {patient_name} - {error_message}") + with self._history_lock: + self.notification_history.append(notification) + if len(self.notification_history) > 200: + self.notification_history = self.notification_history[-200:] + + logger.error("Queued error notification", patient_name=patient_name, error_message=error_message) def show_progress(self, patient_name: str, progress: int, task_id: str): """Show progress notification. @@ -110,14 +118,16 @@ def show_progress(self, patient_name: str, progress: int, task_id: str): def _process_notifications(self) -> None: """Process notifications in background thread.""" - while True: + while self._running: try: # Get notification with timeout notification = self.notification_queue.get(timeout=1.0) - + if not self._running: + break + # Schedule notification on main thread self.app.after(0, lambda n=notification: self._display_notification(n)) - + except Empty: continue except Exception as e: @@ -220,10 +230,13 @@ def _position_toast(self, toast: tk.Toplevel) -> None: def _fade_in(self, window: tk.Toplevel, alpha: float = 0.0) -> None: """Fade in animation for toast.""" - if alpha < 0.9: - alpha += 0.1 - window.attributes("-alpha", alpha) - window.after(20, lambda: self._fade_in(window, alpha)) + try: + if alpha < 0.9: + alpha += 0.1 + window.attributes("-alpha", alpha) + window.after(20, lambda: self._fade_in(window, alpha)) + except tk.TclError: + return def _hide_toast(self, toast: tk.Toplevel) -> None: """Hide and destroy toast notification.""" @@ -235,12 +248,15 @@ def _hide_toast(self, toast: tk.Toplevel) -> None: def _fade_out(self, window: tk.Toplevel, alpha: float = 0.9) -> None: """Fade out animation for toast.""" - if alpha > 0.1: - alpha -= 0.1 - window.attributes("-alpha", alpha) - window.after(20, lambda: self._fade_out(window, alpha)) - else: - window.destroy() + try: + if alpha > 0.1: + alpha -= 0.1 + window.attributes("-alpha", alpha) + window.after(20, lambda: self._fade_out(window, alpha)) + else: + window.destroy() + except tk.TclError: + return def _show_statusbar_notification(self, notification: Dict[str, Any]) -> None: """Show notification in status bar.""" @@ -276,11 +292,13 @@ def get_notification_history(self, limit: int = 50) -> List[Dict[str, Any]]: Returns: List of recent notifications """ - return self.notification_history[-limit:] + with self._history_lock: + return self.notification_history[-limit:] def clear_notification_history(self) -> None: """Clear notification history.""" - self.notification_history.clear() + with self._history_lock: + self.notification_history.clear() logger.info("Notification history cleared") def show_queue_status(self, active_count: int, completed_count: int, failed_count: int) -> None: @@ -340,10 +358,16 @@ def _view_recording(self, recording_id: int, toast: tk.Toplevel) -> None: def cleanup(self) -> None: """Clean up notification manager resources.""" + self._running = False + self.processor_thread.join(timeout=3.0) + # Clear any remaining toasts for toast in self.active_toasts: - if toast.winfo_exists(): - toast.destroy() + try: + if toast.winfo_exists(): + toast.destroy() + except tk.TclError: + pass self.active_toasts.clear() - + logger.info("NotificationManager cleaned up") \ No newline at end of file diff --git a/src/managers/tts_manager.py b/src/managers/tts_manager.py index 59245b6..f4dacc2 100644 --- a/src/managers/tts_manager.py +++ b/src/managers/tts_manager.py @@ -14,6 +14,7 @@ pygame = None PYGAME_AVAILABLE = False +import os import threading from typing import Optional, Dict, Any, List from pydub import AudioSegment @@ -244,14 +245,26 @@ def _play_audio_blocking(self, audio: AudioSegment, output_device: str = None): # Export to temporary file for pygame import tempfile self.logger.info("Playing with pygame (no output device specified)") - with tempfile.NamedTemporaryFile(suffix='.mp3', delete=True) as temp: - audio.export(temp.name, format='mp3') - pygame.mixer.music.load(temp.name) + temp = tempfile.NamedTemporaryFile(suffix='.mp3', delete=False) + temp_path = temp.name + temp.close() + try: + audio.export(temp_path, format='mp3') + pygame.mixer.music.load(temp_path) pygame.mixer.music.play() # Wait for playback to complete while pygame.mixer.music.get_busy(): pygame.time.Clock().tick(10) + finally: + try: + pygame.mixer.music.unload() + except Exception: + pass + try: + os.unlink(temp_path) + except OSError: + pass self.logger.info("Pygame playback completed") else: # Use pydub playback @@ -265,13 +278,18 @@ def _play_audio_blocking(self, audio: AudioSegment, output_device: str = None): def _play_audio_async(self, audio: AudioSegment, output_device: str = None): """Play audio asynchronously (non-blocking). - + Args: audio: AudioSegment to play output_device: Specific output device to use """ - # Create thread for playback - thread = threading.Thread(target=self._play_audio_blocking, args=(audio, output_device)) + def _safe_play(): + try: + self._play_audio_blocking(audio, output_device) + except Exception as e: + self.logger.error(f"Audio playback failed in background: {e}") + + thread = threading.Thread(target=_safe_play) thread.daemon = True thread.start() diff --git a/src/processing/analysis_storage.py b/src/processing/analysis_storage.py index a2a4941..20b2f28 100644 --- a/src/processing/analysis_storage.py +++ b/src/processing/analysis_storage.py @@ -31,6 +31,7 @@ def db(self): """Get the database instance.""" if self._db is None: from database.database import Database + logger.warning(f"{self.__class__.__name__}: creating local Database instance — prefer passing db= to constructor") self._db = Database() return self._db diff --git a/src/rag/cache/sqlite_provider.py b/src/rag/cache/sqlite_provider.py index b38e433..16e5da5 100644 --- a/src/rag/cache/sqlite_provider.py +++ b/src/rag/cache/sqlite_provider.py @@ -40,6 +40,8 @@ def __init__(self, config: CacheConfig): self._db_path = config.sqlite_path or self._get_default_path() self._local = threading.local() self._lock = threading.Lock() + self._connections: dict = {} + self._conn_lock = threading.Lock() # Stats tracking self._hit_count = 0 @@ -73,6 +75,9 @@ def _get_conn(self) -> sqlite3.Connection: # Enable WAL mode for better concurrency self._local.conn.execute("PRAGMA journal_mode=WAL") self._local.conn.execute("PRAGMA synchronous=NORMAL") + # Track connection for cross-thread cleanup + with self._conn_lock: + self._connections[threading.get_ident()] = self._local.conn return self._local.conn def _init_db(self): @@ -415,10 +420,13 @@ def health_check(self) -> bool: return False def close(self): - """Close database connection.""" - if hasattr(self._local, "conn") and self._local.conn: - try: - self._local.conn.close() - except Exception: - pass + """Close all database connections across all threads.""" + with self._conn_lock: + for thread_id, conn in self._connections.items(): + try: + conn.close() + except Exception: + pass + self._connections.clear() + if hasattr(self._local, "conn"): self._local.conn = None diff --git a/src/rag/guidelines_vector_store.py b/src/rag/guidelines_vector_store.py index 53e4415..a97c99b 100644 --- a/src/rag/guidelines_vector_store.py +++ b/src/rag/guidelines_vector_store.py @@ -410,7 +410,10 @@ def search( elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} results.append(GuidelineSearchResult( guideline_id=str(guideline_id), @@ -513,7 +516,10 @@ def search_bm25( elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} # Normalize rank to 0-1 range normalized_score = min(1.0, float(rank) * 10) @@ -909,7 +915,10 @@ def get_guideline_chunks(self, guideline_id: str) -> list[dict]: elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} results.append({ "id": row[0], diff --git a/src/rag/neon_vector_store.py b/src/rag/neon_vector_store.py index bce0c37..bf3bf6a 100644 --- a/src/rag/neon_vector_store.py +++ b/src/rag/neon_vector_store.py @@ -329,7 +329,10 @@ def search( elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} results.append(VectorSearchResult( document_id=str(doc_id), @@ -441,7 +444,10 @@ def get_document_chunks(self, document_id: str) -> list[dict]: elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} results.append({ "id": row[0], @@ -571,7 +577,10 @@ def search_bm25( elif isinstance(metadata_val, dict): metadata = metadata_val else: - metadata = json.loads(metadata_val) + try: + metadata = json.loads(metadata_val) + except (json.JSONDecodeError, TypeError): + metadata = {} # Normalize rank to 0-1 range (ts_rank_cd typically < 1) normalized_score = min(1.0, float(rank) * 10) diff --git a/src/stt_providers/deepgram.py b/src/stt_providers/deepgram.py index 13bac83..0e257ab 100644 --- a/src/stt_providers/deepgram.py +++ b/src/stt_providers/deepgram.py @@ -167,7 +167,7 @@ def transcribe(self, segment: AudioSegment) -> str: if alternatives and "transcript" in alternatives[0]: transcript = alternatives[0]["transcript"] transcript_preview = transcript[:100] + "..." if len(transcript) > 100 else transcript - self.logger.debug(f"Transcript preview: {transcript_preview}") + self.logger.debug(f"Transcript length: {len(transcript)} chars") self.logger.debug("=================================\n") diff --git a/src/stt_providers/modulate.py b/src/stt_providers/modulate.py index 12283f6..74cc1bf 100644 --- a/src/stt_providers/modulate.py +++ b/src/stt_providers/modulate.py @@ -412,7 +412,7 @@ def transcribe(self, segment: AudioSegment, **kwargs) -> str: enable_diarization = settings.get("enable_diarization", True) if enable_diarization and utterances: transcript = self._format_diarized_transcript(utterances) - self.logger.info(f"Diarized transcript preview: {repr(transcript[:200])}") + self.logger.info(f"Diarized transcript length: {len(transcript)} chars") else: transcript = result.get("text", "") self.logger.warning(f"Diarization skipped in transcribe(): " @@ -495,7 +495,7 @@ def transcribe_with_result(self, segment: AudioSegment) -> TranscriptionResult: enable_diarization = settings.get("enable_diarization", True) if enable_diarization and utterances: transcript = self._format_diarized_transcript(utterances) - self.logger.info(f"Diarized transcript preview: {repr(transcript[:200])}") + self.logger.info(f"Diarized transcript length: {len(transcript)} chars") else: transcript = result.get("text", "") self.logger.warning(f"Diarization skipped in transcribe_with_result(): " @@ -636,7 +636,7 @@ def transcribe_file(self, file_path: str) -> tuple: # Format transcript with diarization labels if enable_diarization and utterances: transcript = self._format_diarized_transcript(utterances) - self.logger.info(f"Diarized transcript preview: {repr(transcript[:200])}") + self.logger.info(f"Diarized transcript length: {len(transcript)} chars") else: transcript = result.get("text", "") self.logger.warning(f"Diarization skipped in transcribe_file(): " diff --git a/src/ui/components/streaming_results.py b/src/ui/components/streaming_results.py index c994f6b..c938caf 100644 --- a/src/ui/components/streaming_results.py +++ b/src/ui/components/streaming_results.py @@ -178,12 +178,15 @@ def _animate_to(self, target: float): def animate(): nonlocal current - if abs(current - target) > 1: - current += step - self.progressbar['value'] = current - self.parent.after(50, animate) - else: - self.progressbar['value'] = target + try: + if abs(current - target) > 1: + current += step + self.progressbar['value'] = current + self.parent.after(50, animate) + else: + self.progressbar['value'] = target + except tk.TclError: + pass # Widget destroyed during animation animate() diff --git a/src/ui/dialogs/diagnostic_comparison_dialog.py b/src/ui/dialogs/diagnostic_comparison_dialog.py index a8b738f..64dcb6e 100644 --- a/src/ui/dialogs/diagnostic_comparison_dialog.py +++ b/src/ui/dialogs/diagnostic_comparison_dialog.py @@ -23,23 +23,25 @@ class DiagnosticComparisonDialog: """Dialog for comparing multiple diagnostic analyses side by side.""" - def __init__(self, parent, on_select_callback=None): + def __init__(self, parent, on_select_callback=None, db=None): """Initialize the diagnostic comparison dialog. Args: parent: Parent window on_select_callback: Optional callback when selecting analyses to compare + db: Optional Database instance to reuse (avoids creating a new connection) """ self.parent = parent self.on_select_callback = on_select_callback self.dialog: Optional[tk.Toplevel] = None - self._db: Optional[Database] = None + self._db: Optional[Database] = db self.analyses: List[Dict] = [] self.selected_analyses: List[Dict] = [] def _get_database(self) -> Database: """Get or create database connection.""" if self._db is None: + logger.warning(f"{self.__class__.__name__}: creating local Database instance — prefer passing db= to constructor") self._db = Database() return self._db diff --git a/src/ui/dialogs/diagnostic_history_dialog.py b/src/ui/dialogs/diagnostic_history_dialog.py index 77c37e4..16e229f 100644 --- a/src/ui/dialogs/diagnostic_history_dialog.py +++ b/src/ui/dialogs/diagnostic_history_dialog.py @@ -23,22 +23,24 @@ class DiagnosticHistoryDialog: """Dialog for viewing saved diagnostic analysis history.""" - def __init__(self, parent, on_view_callback=None): + def __init__(self, parent, on_view_callback=None, db=None): """Initialize the diagnostic history dialog. Args: parent: Parent window on_view_callback: Optional callback when viewing an analysis (receives analysis dict) + db: Optional Database instance to reuse (avoids creating a new connection) """ self.parent = parent self.on_view_callback = on_view_callback self.dialog: Optional[tk.Toplevel] = None - self._db: Optional[Database] = None + self._db: Optional[Database] = db self.analyses: List[Dict] = [] def _get_database(self) -> Database: """Get or create database connection.""" if self._db is None: + logger.warning(f"{self.__class__.__name__}: creating local Database instance — prefer passing db= to constructor") self._db = Database() return self._db diff --git a/src/ui/dialogs/help_dialogs.py b/src/ui/dialogs/help_dialogs.py index 9efd119..80548b8 100644 --- a/src/ui/dialogs/help_dialogs.py +++ b/src/ui/dialogs/help_dialogs.py @@ -299,10 +299,13 @@ def open_link(url): dialog.update() def fade_in(alpha=0.0): - if alpha < 1.0: - alpha += 0.1 - dialog.attributes('-alpha', alpha) - dialog.after(20, lambda: fade_in(alpha)) + try: + if alpha < 1.0: + alpha += 0.1 + dialog.attributes('-alpha', alpha) + dialog.after(20, lambda: fade_in(alpha)) + except tk.TclError: + pass # Dialog was closed during animation fade_in() diff --git a/src/ui/dialogs/medication_results_dialog.py b/src/ui/dialogs/medication_results_dialog.py index 331490c..8dd7450 100644 --- a/src/ui/dialogs/medication_results_dialog.py +++ b/src/ui/dialogs/medication_results_dialog.py @@ -24,11 +24,13 @@ class MedicationResultsDialog: """Dialog for displaying medication analysis results.""" - def __init__(self, parent, document_target=None): + def __init__(self, parent, document_target=None, db=None): """Initialize the medication results dialog. Args: parent: Parent window + document_target: Optional document target for inserting content + db: Optional Database instance to reuse (avoids creating a new connection) """ self.parent = parent self._document_target = document_target @@ -42,7 +44,7 @@ def __init__(self, parent, document_target=None): self.patient_context: Optional[Dict[str, Any]] = None self.source_text: str = "" self.dialog: Optional[tk.Toplevel] = None - self._db: Optional[Database] = None + self._db: Optional[Database] = db def show_results( self, @@ -907,6 +909,7 @@ def _export_to_pdf(self): def _get_database(self) -> Database: """Get or create database connection.""" if self._db is None: + logger.warning(f"{self.__class__.__name__}: creating local Database instance — prefer passing db= to constructor") self._db = Database() return self._db diff --git a/src/ui/dialogs/rsvp_dialog.py b/src/ui/dialogs/rsvp_dialog.py index 877d286..33e8d42 100644 --- a/src/ui/dialogs/rsvp_dialog.py +++ b/src/ui/dialogs/rsvp_dialog.py @@ -92,6 +92,7 @@ def __init__(self, parent, text: str): self.current_index = 0 self.is_playing = False self.is_fullscreen = False + self._closed = False self.timer_id: Optional[str] = None # Load settings with validation @@ -704,6 +705,9 @@ def _bind_keys(self) -> None: def _display_word(self) -> None: """Display current word(s) with ORP highlighting on canvas.""" + if self._closed: + return + # Clear canvas self.canvas.delete("all") @@ -1382,8 +1386,12 @@ def _on_resize(self, event) -> None: def _on_close(self) -> None: """Handle dialog close.""" + self._closed = True self.pause() - self.dialog.destroy() + try: + self.dialog.destroy() + except tk.TclError: + pass __all__ = ["RSVPDialog"] diff --git a/src/ui/dialogs/translation/recording.py b/src/ui/dialogs/translation/recording.py index fba6794..5e97754 100644 --- a/src/ui/dialogs/translation/recording.py +++ b/src/ui/dialogs/translation/recording.py @@ -197,7 +197,7 @@ def stop_and_process(): transcript = self.audio_handler.transcribe_audio_without_prefix( combined, diarize_override=False ) - logger.info(f"Transcription result: '{transcript[:100] if transcript else '(empty)'}...'") + logger.info(f"Transcription result length: {len(transcript) if transcript else 0} chars") finally: # Restore original provider if selected_provider: diff --git a/src/ui/dialogs/translation/translation.py b/src/ui/dialogs/translation/translation.py index a0918ee..b99b814 100644 --- a/src/ui/dialogs/translation/translation.py +++ b/src/ui/dialogs/translation/translation.py @@ -58,7 +58,7 @@ def _process_patient_speech(self, transcript: str): Args: transcript: Transcribed text """ - logger.info(f"_process_patient_speech called with transcript: '{transcript[:100] if transcript else '(empty)'}...'") + logger.info(f"_process_patient_speech called, transcript length: {len(transcript) if transcript else 0} chars") # Insert original text self.patient_original_text.delete("1.0", tk.END) self.patient_original_text.insert("1.0", transcript) diff --git a/src/ui/status_manager.py b/src/ui/status_manager.py index 5479d85..f5cfa92 100644 --- a/src/ui/status_manager.py +++ b/src/ui/status_manager.py @@ -122,7 +122,12 @@ def schedule_status_update(self, delay_ms, message, status_type="info"): Returns: Timer ID for the scheduled update """ - timer_id = self.parent.after(delay_ms, lambda: self.update_status(message, status_type)) + def _on_fire(): + if timer_id in self.status_timers: + self.status_timers.remove(timer_id) + self.update_status(message, status_type) + + timer_id = self.parent.after(delay_ms, _on_fire) self.status_timers.append(timer_id) return timer_id diff --git a/src/utils/security/key_storage.py b/src/utils/security/key_storage.py index aa4c8ee..1587c37 100644 --- a/src/utils/security/key_storage.py +++ b/src/utils/security/key_storage.py @@ -144,12 +144,19 @@ def _save_salt(self, salt: bytes): salt: Salt bytes to save """ try: - with open(self.salt_file, 'wb') as f: - f.write(salt) - - # Set restrictive permissions (owner read/write only) if os.name == 'posix': - os.chmod(self.salt_file, 0o600) + # Create file with restrictive permissions from the start + # to avoid TOCTOU window where file is world-readable + fd = os.open(str(self.salt_file), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + try: + with os.fdopen(fd, 'wb') as f: + f.write(salt) + except Exception: + os.close(fd) + raise + else: + with open(self.salt_file, 'wb') as f: + f.write(salt) logger.debug("Salt file saved successfully") except Exception as e: @@ -511,14 +518,20 @@ def _load_keys(self) -> Dict[str, Any]: def _save_keys(self, keys: Dict[str, Any]) -> None: """Save encrypted keys to file.""" try: - # Set restrictive permissions (owner read/write only) - with open(self.key_file, 'w') as f: - json.dump(keys, f, indent=2) - - # Set file permissions (Unix-like systems) + data = json.dumps(keys, indent=2) if os.name == 'posix': - os.chmod(self.key_file, 0o600) - + # Create file with restrictive permissions from the start + # to avoid TOCTOU window where file is world-readable + fd = os.open(str(self.key_file), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + try: + with os.fdopen(fd, 'w') as f: + f.write(data) + except Exception: + os.close(fd) + raise + else: + with open(self.key_file, 'w') as f: + f.write(data) except Exception as e: logger.error(f"Failed to save keys: {e}") raise ConfigurationError(f"Failed to save encrypted keys: {e}")