-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathSmall_Simple_Gradio
More file actions
418 lines (375 loc) · 29 KB
/
Small_Simple_Gradio
File metadata and controls
418 lines (375 loc) · 29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
# -*- coding: utf-8 -*-
# --- Imports ---
import os
import random
import uuid
import json
import time
import traceback # For detailed error printing
import re # Import regex for parsing
# Third-party libraries
import gradio as gr # Import gradio
import torch
import numpy as np
import scipy.io.wavfile as wavfile
import requests
from dotenv import load_dotenv
# --- Whisper Import ---
# Whisper is a hard requirement for speech-to-text: abort start-up with
# installation hints when it is missing.
try:
    import whisper
except ImportError:
    print("ERROR: Whisper library not found. Please install it:")
    print("pip install -U openai-whisper")
    print("Ensure ffmpeg is also installed and in your system PATH.")
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist
    # (e.g. `python -S` or frozen executables).
    raise SystemExit(1)
else:
    print("Whisper library imported successfully.")
# --- End Whisper Import ---
# --- SNAC Import ---
# The SNAC vocoder is a hard requirement for decoding TTS tokens into audio:
# abort start-up with installation hints when it is missing.
try:
    from snac import SNAC
except ImportError:
    print("ERROR: SNAC library not found. Please install it:")
    print("pip install git+https://github.com/hubertsiuzdak/snac.git")
    # Raise SystemExit directly: the bare exit() helper is injected by the
    # `site` module for interactive use and is not guaranteed to exist.
    raise SystemExit(1)
# --- End SNAC Import ---
# --- Load Environment Variables ---
load_dotenv()
# --- Ollama Configuration ---
# Base URL of the Ollama server. An unset or empty environment value falls
# back to the local default, and any trailing slash is dropped so endpoint
# paths can be appended without producing "//".
OLLAMA_BASE_URL = (os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434").rstrip("/")
OLLAMA_MODEL = "gemma3:12b"
OLLAMA_API_ENDPOINT = f"{OLLAMA_BASE_URL}/api/chat"
# System prompt sent as the first message of every Ollama chat request.
OLLAMA_SYSTEM_PROMPT = "You are a helpful AI assistant who always limits responses to three complete sentences."
# --- GGUF TTS Server Configuration ---
# Base URL of the server hosting the Orpheus GGUF model; normalised the same
# way as the Ollama URL above.
GGUF_SERVER_URL = (os.getenv("GGUF_TTS_SERVER_URL") or "http://127.0.0.1:1234").rstrip("/")
GGUF_COMPLETIONS_ENDPOINT = f"{GGUF_SERVER_URL}/v1/completions"
# --- Device Setup (SNAC and Whisper) ---
# The SNAC vocoder and Whisper always share one device: CUDA when torch
# reports it as available, otherwise the CPU.
tts_device = stt_device = "cuda" if torch.cuda.is_available() else "cpu"
if tts_device == "cuda":
    print("SNAC vocoder and Whisper STT will use CUDA if possible.")
else:
    print("CUDA not available. SNAC vocoder and Whisper STT will use CPU.")
# --- TTS Model Loading (SNAC ONLY) ---
print("Loading SNAC vocoder model...")
# snac_model stays None when loading fails; later code checks for None
# instead of aborting start-up.
snac_model = None
try:
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
    snac_model = snac_model.to(tts_device)
    snac_model.eval()
    print(f"SNAC vocoder loaded to {tts_device}")
except Exception as snac_error:
    # Deliberately non-fatal (a hard exit was left disabled here).
    print(f"Error loading SNAC model: {snac_error}")

# --- Load Whisper Model ---
WHISPER_MODEL_NAME = "base.en"
print(f"Loading Whisper STT model '{WHISPER_MODEL_NAME}' to {stt_device}...")
try:
    whisper_model = whisper.load_model(WHISPER_MODEL_NAME, device=stt_device)
    print(f"Whisper model '{WHISPER_MODEL_NAME}' loaded successfully.")
except Exception as whisper_error:
    print(f"Error loading Whisper model: {whisper_error}")
    # Audio input is skipped later when the model is None.
    whisper_model = None
print("NOTE: Orpheus TTS model (GGUF) should be loaded in a separate server (e.g., LM Studio).")
# --- Global Parameters & Constants ---
# Token budgets and seed range.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_OLLAMA_MAX_TOKENS = 256
MAX_SEED = np.iinfo(np.int32).max

# Audio token ID layout: valid IDs lie in [ORPHEUS_MIN_ID, ORPHEUS_MAX_ID),
# split into ORPHEUS_N_LAYERS consecutive bands of ORPHEUS_TOKENS_PER_LAYER IDs.
ORPHEUS_MIN_ID = 10
ORPHEUS_TOKENS_PER_LAYER = 4096
ORPHEUS_N_LAYERS = 7
ORPHEUS_MAX_ID = ORPHEUS_MIN_ID + (ORPHEUS_N_LAYERS * ORPHEUS_TOKENS_PER_LAYER)

# Default sampling parameters for the Ollama LLM.
DEFAULT_OLLAMA_TEMP = 0.7
DEFAULT_OLLAMA_TOP_P = 0.9
DEFAULT_OLLAMA_TOP_K = 40
DEFAULT_OLLAMA_REP_PENALTY = 1.1

# Default sampling parameters for the GGUF TTS server.
DEFAULT_TTS_TEMP = 0.4
DEFAULT_TTS_TOP_P = 0.9
DEFAULT_TTS_TOP_K = 40
DEFAULT_TTS_REP_PENALTY = 1.1

# Number of past chat turns included in the LLM context.
CONTEXT_TURN_LIMIT = 3
# --- Utility Functions ---
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in [0, MAX_SEED] when requested, else ``seed`` unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def clean_chat_history(limited_chat_history):
    """Convert displayed chat turns into Ollama ``messages`` entries.

    Args:
        limited_chat_history: Iterable of ``(user_message, bot_message)``
            pairs as shown in the chatbot. Each side may be a string, a
            tuple/list such as ``(audio_path, text)``, or ``None``.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts ("user" /
        "assistant") holding only the textual content worth sending to the
        LLM. Audio-only replies, status messages and error messages are
        dropped.
    """
    if not limited_chat_history:
        return []

    audio_marker = "🎤 (Audio Input): "
    # Strip the routing tag for ANY voice (e.g. "@jess-llm "), not only
    # tara's, so command prefixes never leak into the LLM context.
    command_prefix = re.compile(r"^@[a-z]+-(?:tts|llm)(?:\s+|$)", re.IGNORECASE)
    # Bot strings that are status/error output rather than real replies.
    # "[Unexpected" and "[An unexpected" cover the error formats produced
    # elsewhere in this file that do not start with "[Error".
    non_reply_prefixes = (
        "[Error", "[Unexpected", "[An unexpected", "(Error", "Sorry,",
        "(No input", "Processing", "(TTS failed",
    )
    audio_suffixes = (".wav", ".mp3")

    cleaned_ollama_format = []
    for turn in limited_chat_history:
        try:
            user_msg_display, bot_msg_display = turn
        except (TypeError, ValueError):
            # Malformed turn (not a pair): skip it rather than crash.
            continue

        user_text = None
        if isinstance(user_msg_display, str):
            user_text = user_msg_display
            if user_text.startswith(audio_marker):
                user_text = user_text[len(audio_marker):]
            user_text = command_prefix.sub("", user_text, count=1)
        elif isinstance(user_msg_display, (tuple, list)) and user_msg_display:
            # Fallback for (file, text) style user messages.
            if len(user_msg_display) > 1 and isinstance(user_msg_display[1], str):
                user_text = user_msg_display[1].replace("🎤: ", "")
            elif isinstance(user_msg_display[0], str) and not user_msg_display[0].endswith(audio_suffixes):
                user_text = user_msg_display[0]

        bot_text = None
        if isinstance(bot_msg_display, (tuple, list)):
            # (audio_path, spoken_text): only the text part is usable context.
            if len(bot_msg_display) > 1 and isinstance(bot_msg_display[1], str):
                bot_text = bot_msg_display[1]
        elif isinstance(bot_msg_display, str):
            if not bot_msg_display.startswith(non_reply_prefixes):
                bot_text = bot_msg_display

        if user_text and user_text.strip():
            cleaned_ollama_format.append({"role": "user", "content": user_text})
        if bot_text and bot_text.strip():
            cleaned_ollama_format.append({"role": "assistant", "content": bot_text})
    return cleaned_ollama_format
# --- TTS Pipeline Functions (GGUF Server + SNAC) ---
def parse_gguf_codes(response_text):
    """Extract in-range audio token IDs from ``<custom_token_N>`` markers.

    Returns the IDs, in order of appearance, that satisfy
    ``ORPHEUS_MIN_ID <= id < ORPHEUS_MAX_ID``; an empty list when no marker
    is present or none is in range.
    """
    valid_ids = []
    for marker in re.finditer(r"<custom_token_(\d+)>", response_text):
        try:
            candidate = int(marker.group(1))
        except ValueError:
            continue
        if ORPHEUS_MIN_ID <= candidate < ORPHEUS_MAX_ID:
            valid_ids.append(candidate)
    if valid_ids:
        print(f" - Parsed {len(valid_ids)} valid audio token IDs using regex.")
    return valid_ids
def redistribute_codes(absolute_code_list, target_snac_model):
    """Regroup absolute audio token IDs into three SNAC layers and decode them.

    Tokens are consumed in groups of ``ORPHEUS_N_LAYERS``. Within a group the
    token at position ``j`` must fall in the j-th ID band; groups that violate
    this are skipped. Valid groups are spread over three code lists
    (positions 0 / 1,4 / 2,3,5,6) and passed to ``target_snac_model.decode``.

    Returns:
        The decoded audio as a NumPy array, or ``None`` when there is no
        model, no complete valid group, or decoding fails.
    """
    if not absolute_code_list or target_snac_model is None:
        return None

    snac_device = next(target_snac_model.parameters()).device
    token_total = len(absolute_code_list)
    frame_total = token_total // ORPHEUS_N_LAYERS
    if frame_total == 0:
        return None
    print(f" - Processing {frame_total} groups of {ORPHEUS_N_LAYERS} codes for SNAC...")

    layer_1, layer_2, layer_3 = [], [], []
    for frame_no in range(frame_total):
        first = frame_no * ORPHEUS_N_LAYERS
        last = first + ORPHEUS_N_LAYERS
        if last > token_total:
            break
        frame_codes = [None] * ORPHEUS_N_LAYERS
        for slot, token_id in enumerate(absolute_code_list[first:last]):
            if not (ORPHEUS_MIN_ID <= token_id < ORPHEUS_MAX_ID):
                break
            band, code = divmod(token_id - ORPHEUS_MIN_ID, ORPHEUS_TOKENS_PER_LAYER)
            if band != slot:
                break
            frame_codes[slot] = code
        else:
            # Reached only when every slot of the frame passed validation.
            try:
                layer_1.append(frame_codes[0])
                layer_2.append(frame_codes[1])
                layer_3.append(frame_codes[2])
                layer_3.append(frame_codes[3])
                layer_2.append(frame_codes[4])
                layer_3.append(frame_codes[5])
                layer_3.append(frame_codes[6])
            except (IndexError, TypeError):
                continue

    try:
        if not layer_1 or not layer_2 or not layer_3:
            return None
        print(f" - Final SNAC layer sizes: L1={len(layer_1)}, L2={len(layer_2)}, L3={len(layer_3)}")
        codes = [
            torch.tensor(layer, device=snac_device, dtype=torch.long).unsqueeze(0)
            for layer in (layer_1, layer_2, layer_3)
        ]
        with torch.no_grad():
            audio_hat = target_snac_model.decode(codes)
        return audio_hat.detach().squeeze().cpu().numpy()
    except Exception as e:
        print(f"Error during tensor creation or SNAC decoding: {e}")
        return None
def generate_speech_gguf(text, voice, tts_temperature, tts_top_p, tts_repetition_penalty, max_new_tokens_audio):
    """Synthesize ``text`` in ``voice`` via the GGUF completions server and SNAC.

    Posts a completion request to ``GGUF_COMPLETIONS_ENDPOINT``, parses the
    audio token IDs out of the returned text and decodes them with the
    module-level ``snac_model``.

    Returns:
        ``(24000, samples)`` on success, otherwise ``None`` (blank text, no
        SNAC model, request failure, unexpected response, or no decodable
        audio codes).
    """
    if not text.strip() or snac_model is None:
        return None

    print(f"Generating speech via GGUF server for: '{text[:50]}...'")
    start_time = time.time()
    payload = {
        "prompt": f"<|audio|>{voice}: {text}<|eot_id|>",
        "temperature": tts_temperature,
        "top_p": tts_top_p,
        "repeat_penalty": tts_repetition_penalty,
        "n_predict": max_new_tokens_audio,
        "stop": ["<|eot_id|>", "<|audio|>"],
        "stream": False,
    }
    print(f" - Sending payload to {GGUF_COMPLETIONS_ENDPOINT} (TTS Temp: {tts_temperature}, TTS Top-P: {tts_top_p}, TTS RepPen: {tts_repetition_penalty})")
    try:
        response = requests.post(GGUF_COMPLETIONS_ENDPOINT, json=payload, timeout=180)
        response.raise_for_status()
        response_json = response.json()

        has_text = (
            "choices" in response_json
            and len(response_json["choices"]) > 0
            and "text" in response_json["choices"][0]
        )
        if not has_text:
            print(f"Error: Unexpected response format: {response_json}")
            return None
        raw_generated_text = response_json["choices"][0]["text"].strip()
        if not raw_generated_text:
            return None
        req_time = time.time()
        print(f" - GGUF server request took {req_time - start_time:.2f}s")

        absolute_id_list = parse_gguf_codes(raw_generated_text)
        if not absolute_id_list:
            print("Error: No valid audio codes parsed.")
            return None

        audio_samples = redistribute_codes(absolute_id_list, snac_model)
        if audio_samples is None:
            print("Error: Failed to generate audio samples.")
            return None
        snac_time = time.time()
        print(f" - Generated audio samples via SNAC, shape: {audio_samples.shape} (took {snac_time - req_time:.2f}s)")
        print(f" - Total TTS generation time: {snac_time - start_time:.2f}s")
        return (24000, audio_samples)
    except requests.exceptions.RequestException as e:
        print(f"Error during request to GGUF server: {e}")
        return None
    except Exception as e:
        print(f"Error during GGUF speech generation pipeline: {e}")
        traceback.print_exc()
        return None
# --- Ollama Communication Helper ---
def call_ollama_non_streaming(ollama_payload, generation_params):
    """Send one non-streaming chat request to Ollama and return the reply text.

    Args:
        ollama_payload: Request body containing at least ``model`` and
            ``messages``. It is copied, not modified; ``options`` and
            ``stream`` are set on the copy.
        generation_params: Mapping with optional ``ollama_temperature``,
            ``ollama_top_p``, ``ollama_top_k``, ``ollama_max_new_tokens`` and
            ``ollama_repetition_penalty`` entries. Missing entries use the
            module defaults; entries explicitly set to ``None`` are omitted
            from the request. ``None`` is treated as an empty mapping.

    Returns:
        The stripped assistant message on success. On any failure, a string
        starting with ``"[Error"``, the prefix callers in this file test
        for to tell failures from real replies.
    """
    generation_params = generation_params or {}
    final_response = "[Error: Default response]"
    try:
        options = {
            "temperature": generation_params.get('ollama_temperature', DEFAULT_OLLAMA_TEMP),
            "top_p": generation_params.get('ollama_top_p', DEFAULT_OLLAMA_TOP_P),
            "top_k": generation_params.get('ollama_top_k', DEFAULT_OLLAMA_TOP_K),
            "num_predict": generation_params.get('ollama_max_new_tokens', DEFAULT_OLLAMA_MAX_TOKENS),
            "repeat_penalty": generation_params.get('ollama_repetition_penalty', DEFAULT_OLLAMA_REP_PENALTY),
        }
        # Work on a shallow copy so the caller's dict is left untouched.
        request_body = dict(ollama_payload)
        request_body['options'] = {k: v for k, v in options.items() if v is not None}
        request_body['stream'] = False
        print(f" - Sending non-streaming request to Ollama with options: {request_body.get('options', {})}")

        start_time = time.time()
        response = requests.post(OLLAMA_API_ENDPOINT, json=request_body, timeout=180)
        response.raise_for_status()
        response_json = response.json()
        end_time = time.time()
        print(f" - Ollama request took {end_time - start_time:.2f}s")

        if "message" in response_json and "content" in response_json["message"]:
            final_response = response_json["message"]["content"].strip()
        elif "error" in response_json:
            final_response = f"[Error from Ollama: {response_json['error']}]"
            print(final_response)
        else:
            final_response = "[Error: Unexpected Ollama response format]"
    except requests.exceptions.RequestException as e:
        final_response = f"[Error connecting to Ollama: {e}]"
    except Exception as e:
        # Keep the "[Error" prefix: the previous "[Unexpected Ollama Error"
        # wording slipped past callers' startswith("[Error") checks and was
        # treated as a genuine LLM reply.
        final_response = f"[Error: Unexpected Ollama error: {e}]"
        traceback.print_exc()
    print(f" - Ollama response: '{final_response[:100]}...'")
    return final_response
# --- Main Gradio Backend Function --- #
def _save_tts_audio(audio_output, label):
    """Write generated audio to a uniquely named 16-bit WAV file and return its path.

    ``audio_output`` is the ``(sample_rate, samples)`` pair produced by
    ``generate_speech_gguf``; ``label`` only appears in the log line.
    """
    sample_rate, audio_data = audio_output
    if audio_data.dtype != np.int16:
        if np.issubdtype(audio_data.dtype, np.floating):
            # Peak-normalise floats to the int16 range; near-silent audio
            # becomes zeros to avoid dividing by ~0.
            max_val = np.max(np.abs(audio_data))
            if max_val > 1e-6:
                audio_data = np.int16(audio_data / max_val * 32767)
            else:
                audio_data = np.zeros_like(audio_data, dtype=np.int16)
        else:
            audio_data = audio_data.astype(np.int16)
    temp_dir = "temp_audio_files"
    os.makedirs(temp_dir, exist_ok=True)
    temp_audio_path = os.path.join(temp_dir, f"temp_audio_{uuid.uuid4().hex}.wav")
    wavfile.write(temp_audio_path, sample_rate, audio_data)
    print(f" - Saved {label} audio: {temp_audio_path}")
    return temp_audio_path


def _ask_ollama_with_history(chat_history, user_prompt, llm_params):
    """Query Ollama with the system prompt, recent cleaned history and ``user_prompt``.

    The last entry of ``chat_history`` is the turn currently being answered
    and is excluded from the context.
    """
    history_before_current = chat_history[:-1]
    limited_history_turns = history_before_current[-CONTEXT_TURN_LIMIT:]
    messages = (
        [{"role": "system", "content": OLLAMA_SYSTEM_PROMPT}]
        + clean_chat_history(limited_history_turns)
        + [{"role": "user", "content": user_prompt}]
    )
    payload = {"model": OLLAMA_MODEL, "messages": messages}
    return call_ollama_non_streaming(payload, llm_params)


def process_input_blocks(
    text_input: str, audio_input_path: str,
    auto_prefix_tts_checkbox: bool,
    auto_prefix_llm_checkbox: bool,
    plain_llm_checkbox: bool,
    ollama_max_new_tokens: int, ollama_temperature: float, ollama_top_p: float, ollama_top_k: int, ollama_repetition_penalty: float,
    tts_temperature: float, tts_top_p: float, tts_repetition_penalty: float,
    chat_history: list
):
    """Handle one submission from the Gradio UI.

    Input selection: a recorded audio file is transcribed with Whisper and
    used when the transcription is non-empty; otherwise the typed text is
    used. Input that is empty after stripping is ignored.

    Routing: text starting with ``@<voice>-tts`` is spoken directly, text
    starting with ``@<voice>-llm`` is answered by Ollama and the answer is
    spoken, and anything else is answered by Ollama as plain text. When the
    input carries no command tag, the checkboxes choose the default with
    priority plain LLM > ``@tara-tts`` > ``@tara-llm``.

    Args:
        text_input: Content of the text box.
        audio_input_path: Path of the recorded audio file, or a falsy value.
        auto_prefix_tts_checkbox: Default untagged input to ``@tara-tts``.
        auto_prefix_llm_checkbox: Default untagged input to ``@tara-llm``.
        plain_llm_checkbox: Force untagged input to plain LLM text chat.
        ollama_max_new_tokens, ollama_temperature, ollama_top_p, ollama_top_k,
        ollama_repetition_penalty: Sampling options forwarded to Ollama.
        tts_temperature, tts_top_p, tts_repetition_penalty: Sampling options
            forwarded to the GGUF TTS server.
        chat_history: List of ``(user_message, bot_message)`` pairs; it is
            appended to in place. ``None`` is treated as an empty history.

    Returns:
        ``(chat_history, None, None)``: the updated history plus two ``None``
        values that clear the text box and the audio component.
    """
    if chat_history is None:
        chat_history = []

    all_voices = ["tara", "jess", "leo", "leah", "dan", "mia", "zac", "zoe"]
    tts_tags = {f"@{voice}-tts": voice for voice in all_voices}
    llm_tags = {f"@{voice}-llm": voice for voice in all_voices}
    known_prefixes = tuple(tts_tags) + tuple(llm_tags)

    original_user_input_text = ""
    user_display_input = None
    text_to_process = ""
    transcription_source = "text"
    audio_filepath_to_clean = None
    force_plain_llm = False  # Flag to force Branch 3

    # --- 1. Determine Input & Apply Prefix ---
    # Handle Audio Input
    if audio_input_path and whisper_model:
        if os.path.isfile(audio_input_path):
            audio_filepath_to_clean = audio_input_path
            print(f"Processing audio input: {audio_input_path}")
            try:
                stt_start_time = time.time()
                result = whisper_model.transcribe(audio_input_path, fp16=(stt_device == 'cuda'))
                transcript = result["text"].strip()
                stt_end_time = time.time()
            except Exception as e:  # Handle Transcription Error
                print(f"Error during Whisper transcription: {e}")
                traceback.print_exc()
                error_msg = f"[Error during local transcription: {e}]"
                chat_history.append((f"🎤 (Audio Input Error: {audio_input_path})", error_msg))
                if os.path.exists(audio_filepath_to_clean):
                    try:
                        os.remove(audio_filepath_to_clean)
                    except Exception as e_clean:
                        print(f"Warning: Could not clean up STT temp file {audio_filepath_to_clean}: {e_clean}")
                return chat_history, None, None
            print(f" - Whisper transcription: '{transcript}' (took {stt_end_time - stt_start_time:.2f}s)")

            if transcript:
                transcription_source = "voice"
                original_user_input_text = transcript
                user_display_input = f"🎤 (Audio Input): {transcript}"  # Display text
                text_to_process = transcript
                if transcript.lower().startswith(known_prefixes):
                    print(f" - Transcribed audio is already a command '{transcript[:20]}...'.")
                else:
                    # Apply checkbox logic to transcribed audio text
                    prefix_to_add = None
                    if plain_llm_checkbox:
                        force_plain_llm = True
                        print(f" - Plain LLM checked. Processing audio as text input for LLM.")
                    elif auto_prefix_tts_checkbox:
                        prefix_to_add = "@tara-tts"
                        print(f" - Auto-prefix TTS checked. Applying to audio.")
                    elif auto_prefix_llm_checkbox:
                        prefix_to_add = "@tara-llm"
                        print(f" - Auto-prefix LLM checked. Applying to audio.")
                    else:
                        print(f" - No default prefix checkbox checked for audio. Processing as text for LLM.")
                    if prefix_to_add:
                        text_to_process = f"{prefix_to_add} {transcript}"
            else:
                # An empty transcription must not be turned into a bare
                # "@tara-tts " command; fall through to the typed text.
                print(" - Whisper transcription was empty; falling back to text input.")
        else:
            print(f"Received invalid audio path: {audio_input_path}, falling back to text.")

    # Handle Text Input (if no valid audio processed)
    if not text_to_process and text_input:
        typed_text = text_input.strip()
        if typed_text:  # whitespace-only input counts as no input
            original_user_input_text = typed_text
            user_display_input = typed_text
            print(f"Processing text input: '{typed_text}'")
            transcription_source = "text"
            text_to_process = typed_text
            if typed_text.lower().startswith(known_prefixes):
                print(f" - User provided command in text '{typed_text[:20]}...', not auto-prepending.")
            else:
                prefix_to_add = None
                if plain_llm_checkbox:
                    force_plain_llm = True
                    print(f" - Plain LLM checked. Processing text input for LLM.")
                elif auto_prefix_tts_checkbox:
                    prefix_to_add = "@tara-tts"
                    print(f" - Auto-prefix TTS checked. Applying to text.")
                elif auto_prefix_llm_checkbox:
                    prefix_to_add = "@tara-llm"
                    print(f" - Auto-prefix LLM checked. Applying to text.")
                else:
                    print(f" - No default prefix checkbox enabled for text input.")
                if prefix_to_add:
                    text_to_process = f"{prefix_to_add} {typed_text}"

    # Cleanup successful transcription file
    if audio_filepath_to_clean and os.path.exists(audio_filepath_to_clean):
        try:
            os.remove(audio_filepath_to_clean)
            print(f" - Cleaned up temporary STT audio file: {audio_filepath_to_clean}")
        except Exception as e_clean:
            print(f"Warning: Could not clean up temp STT audio file {audio_filepath_to_clean}: {e_clean}")

    if not text_to_process:  # Handle Empty
        print("No valid text or audio input to process.")
        return chat_history, None, None

    chat_history.append((user_display_input, None))  # Add user message to history

    # --- 2. Process Input Text (Routing) ---
    lower_text = text_to_process.lower()
    print(f" - Routing query ({transcription_source}): '{text_to_process[:100]}...'")
    llm_params = {
        'ollama_temperature': ollama_temperature,
        'ollama_top_p': ollama_top_p,
        'ollama_top_k': ollama_top_k,
        'ollama_max_new_tokens': ollama_max_new_tokens,
        'ollama_repetition_penalty': ollama_repetition_penalty,
    }
    final_bot_message = None
    try:
        # Command tags are only honoured when plain LLM chat is not forced.
        tts_match = None
        llm_match = None
        if not force_plain_llm:
            tts_match = next(((tag, voice) for tag, voice in tts_tags.items() if lower_text.startswith(tag)), None)
            if tts_match is None:
                llm_match = next(((tag, voice) for tag, voice in llm_tags.items() if lower_text.startswith(tag)), None)

        if tts_match is not None:
            # Branch 1: Direct TTS
            tag, voice = tts_match
            text_to_speak = text_to_process[len(tag):].strip()
            print(f" - Direct TTS request (GGUF) for voice '{voice}': '{text_to_speak[:50]}...'")
            if snac_model is None:
                raise ValueError("SNAC vocoder not loaded.")
            audio_output = generate_speech_gguf(text_to_speak, voice, tts_temperature, tts_top_p, tts_repetition_penalty, MAX_MAX_NEW_TOKENS)
            if audio_output:
                final_bot_message = (_save_tts_audio(audio_output, "TTS"), None)
            else:
                final_bot_message = f"Sorry, couldn't generate speech for '{text_to_speak[:50]}...'."
        elif llm_match is not None:
            # Branch 2: LLM + TTS
            tag, voice = llm_match
            prompt_for_llm = text_to_process[len(tag):].strip()
            print(f" - Ollama+TTS request (GGUF) voice '{voice}', prompt: '{prompt_for_llm[:75]}...'")
            if snac_model is None:
                raise ValueError("SNAC vocoder not loaded.")
            llm_response_text = _ask_ollama_with_history(chat_history, prompt_for_llm, llm_params)
            # "[Unexpected" is also a failure marker produced by the Ollama
            # helper; never send an error message to the TTS engine.
            if llm_response_text and not llm_response_text.startswith(("[Error", "[Unexpected")):
                audio_output = generate_speech_gguf(llm_response_text, voice, tts_temperature, tts_top_p, tts_repetition_penalty, MAX_MAX_NEW_TOKENS)
                if audio_output:
                    final_bot_message = (_save_tts_audio(audio_output, "LLM+TTS"), llm_response_text)
                else:
                    print("Warning: GGUF TTS generation failed...")
                    final_bot_message = f"{llm_response_text}\n\n(TTS failed...)"
            else:
                final_bot_message = llm_response_text
        else:
            # Branch 3: Default/Plain LLM Text Chat. Reached if plain chat is
            # forced OR no command tag was matched/added.
            if force_plain_llm:
                print(f" - Plain LLM chat mode forced by checkbox...")
            else:
                print(f" - Default Ollama text chat (no command prefix detected/added)...")
            # Use the ORIGINAL user input text; the reply is stored as TEXT.
            final_bot_message = _ask_ollama_with_history(chat_history, original_user_input_text, llm_params)
    except Exception as e:
        print(f"Error during processing: {e}")
        traceback.print_exc()
        final_bot_message = f"[An unexpected error occurred: {e}]"

    # --- 3. Update History and Return ---
    chat_history[-1] = (user_display_input, final_bot_message)
    return chat_history, None, None  # Return history, clear text, clear audio
# --- Helper functions for Checkbox Interaction ---
# The helper below keeps the three default-mode checkboxes mutually exclusive (only one checked at a time).
def update_prefix_checkboxes(selected_checkbox_label):
    """Return updates that check only the selected prefix checkbox.

    The result is a 3-tuple of ``gr.update`` objects in the order
    (tts, llm, plain). An unrecognised label yields three no-op updates.
    """
    checkbox_order = ("tts", "llm", "plain")
    if selected_checkbox_label not in checkbox_order:
        # Should not happen, but leave the current state untouched.
        return gr.update(), gr.update(), gr.update()
    return tuple(
        gr.update(value=(label == selected_checkbox_label))
        for label in checkbox_order
    )
# --- Gradio Interface Definition using gr.Blocks ---
print("Setting up Gradio Interface with gr.Blocks...")
theme_to_use = None
with gr.Blocks(theme=theme_to_use) as demo:
    gr.Markdown(f"# Orpheus Edge 🎤 Ollama ({OLLAMA_MODEL.split(':')[0]}) Chat & Local GGUF TTS (Whisper STT + gr.Audio)")
    chatbot = gr.Chatbot(label="Chat History", height=500)
    with gr.Row():
        with gr.Column(scale=3):
            text_input_box = gr.Textbox(label="Type your message or use microphone", lines=2)
        with gr.Column(scale=1):
            audio_input_mic = gr.Audio(label="Record Audio Input", type="filepath")
    # Default-mode checkboxes (kept mutually exclusive by the handlers below).
    with gr.Row():
        auto_prefix_tts_checkbox = gr.Checkbox(label="Default to TTS (@tara-tts)", value=True, elem_id="cb_tts")
        auto_prefix_llm_checkbox = gr.Checkbox(label="Default to LLM+TTS (@tara-llm)", value=False, elem_id="cb_llm")
        plain_llm_checkbox = gr.Checkbox(label="Plain LLM Chat (Text Out)", value=False, elem_id="cb_plain")
    with gr.Row():
        submit_button = gr.Button("Send / Submit")
        clear_button = gr.ClearButton([text_input_box, audio_input_mic, chatbot])
    with gr.Accordion("Generation Parameters", open=False) as param_accordion:
        gr.Markdown("### Ollama LLM Parameters")
        ollama_max_new_tokens_slider = gr.Slider(label="LLM Max New Tokens", minimum=32, maximum=4096, step=32, value=DEFAULT_OLLAMA_MAX_TOKENS)
        ollama_temperature_slider = gr.Slider(label="LLM Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_TEMP)
        ollama_top_p_slider = gr.Slider(label="LLM Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_OLLAMA_TOP_P)
        ollama_top_k_slider = gr.Slider(label="LLM Top-k", minimum=1, maximum=100, step=1, value=DEFAULT_OLLAMA_TOP_K)
        ollama_repetition_penalty_slider = gr.Slider(label="LLM Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_OLLAMA_REP_PENALTY)
        gr.Markdown("---")
        gr.Markdown("### GGUF TTS Server Parameters")
        tts_temperature_slider = gr.Slider(label="TTS Temperature", minimum=0.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_TEMP)
        tts_top_p_slider = gr.Slider(label="TTS Top-p", minimum=0.05, maximum=1.0, step=0.05, value=DEFAULT_TTS_TOP_P)
        tts_repetition_penalty_slider = gr.Slider(label="TTS Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=DEFAULT_TTS_REP_PENALTY)
    param_inputs = [
        ollama_max_new_tokens_slider, ollama_temperature_slider, ollama_top_p_slider,
        ollama_top_k_slider, ollama_repetition_penalty_slider,
        tts_temperature_slider, tts_top_p_slider, tts_repetition_penalty_slider,
    ]

    # ---> Checkbox interactions for mutual exclusivity <---
    # Use `.input` (fires on user interaction only) rather than `.change`:
    # `.change` also fires when a handler updates a checkbox, so the three
    # handlers would re-trigger one another after every click.
    prefix_checkboxes = [auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox]
    auto_prefix_tts_checkbox.input(
        fn=lambda: update_prefix_checkboxes("tts"),
        inputs=[],  # the lambda already encodes which checkbox was clicked
        outputs=prefix_checkboxes,
    )
    auto_prefix_llm_checkbox.input(
        fn=lambda: update_prefix_checkboxes("llm"),
        inputs=[],
        outputs=prefix_checkboxes,
    )
    plain_llm_checkbox.input(
        fn=lambda: update_prefix_checkboxes("plain"),
        inputs=[],
        outputs=prefix_checkboxes,
    )
    # --- End Interactions ---

    # The order here must match the parameter order of process_input_blocks.
    all_inputs = [text_input_box, audio_input_mic, auto_prefix_tts_checkbox, auto_prefix_llm_checkbox, plain_llm_checkbox] + param_inputs + [chatbot]
    submit_button.click(fn=process_input_blocks, inputs=all_inputs, outputs=[chatbot, text_input_box, audio_input_mic])
    text_input_box.submit(fn=process_input_blocks, inputs=all_inputs, outputs=[chatbot, text_input_box, audio_input_mic])
# --- Application Entry Point ---
if __name__ == "__main__":
    # Print a start-up summary, then launch the Gradio app.
    divider = "-" * 50

    print(divider)
    print(f"Launching Gradio {gr.__version__} Interface (gr.Blocks with gr.Audio)...")

    # Model status.
    if whisper_model is not None:
        print(f"Whisper STT Model: {WHISPER_MODEL_NAME} on {stt_device}")
    else:
        print("WARNING: Whisper model failed...")
    if snac_model is not None:
        print(f"SNAC Vocoder loaded to {tts_device}")
    else:
        print("WARNING: SNAC vocoder not loaded...")

    # Backend endpoints.
    print(f"Ollama URL: {OLLAMA_BASE_URL}, Model: {OLLAMA_MODEL}")
    print(f"GGUF TTS Server URL: {GGUF_SERVER_URL}")
    print(divider)

    # Default generation parameters.
    print("Default Parameters:")
    print(f" LLM: Temp={DEFAULT_OLLAMA_TEMP}, TopP={DEFAULT_OLLAMA_TOP_P}, TopK={DEFAULT_OLLAMA_TOP_K}, RepPen={DEFAULT_OLLAMA_REP_PENALTY}, MaxT={DEFAULT_OLLAMA_MAX_TOKENS}")
    print(f" TTS: Temp={DEFAULT_TTS_TEMP}, TopP={DEFAULT_TTS_TOP_P}, RepPen={DEFAULT_TTS_REP_PENALTY}")
    print(divider)

    # Usage hints.
    print("Ensure Ollama and the GGUF TTS server are running.")
    print("Use 'Record Audio Input' or the Text Box.")
    print("Use checkboxes to control default behavior for input (TTS, LLM+TTS, or Plain LLM).")
    print("Then click 'Send / Submit'.")
    print(divider)
    print("!!! IMPORTANT: Check browser console (F12 -> Console) for errors if audio component fails. !!!")

    # Generated WAV files are written to this directory.
    os.makedirs("temp_audio_files", exist_ok=True)
    demo.launch(share=False)
    print("Gradio Interface launched. Press Ctrl+C to stop.")