diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index cbaeef7..87ed1ce 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -112,6 +112,13 @@ Top-level directories (one level down): - Prefer frontend changes inside features// and keep API calls in services/. - Trust these instructions and only search the repo if something is missing or contradicts these notes. +## Test Data Safety (Mandatory) + +- Never read from or write to real user runtime files under `data/config/`, `data/projects/`, or `data/logs/` when running tests. +- For all automated tests, force the app to use a temporary user data root by setting `AUGQ_USER_DATA_DIR` to a temp directory (for example under `/tmp`). +- Tests and AI-generated test code must isolate runtime paths via environment variables before importing app modules so default config constants resolve into temp paths. +- When tests need projects/registry overrides, use temp values for `AUGQ_PROJECTS_ROOT` and `AUGQ_PROJECTS_REGISTRY` inside that same temp root. + ## Branching and Release Policy The repository uses the following branch layout by default: diff --git a/.gitignore b/.gitignore index 8d48322..b50c628 100644 --- a/.gitignore +++ b/.gitignore @@ -218,6 +218,7 @@ __marimo__/ # User configuration and data resources/config/*.json !resources/config/examples/*.json +!resources/config/model_presets.json data/ # Build artifacts diff --git a/README.md b/README.md index 7e4c8d9..0fff518 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,13 @@ If you want to modify the frontend and see changes on the fly: Configuration is JSON-based with environment variable precedence and interpolation. -- Machine-specific config (API credentials/endpoints): resources/config/machine.json -- Story-specific config (active project): resources/config/story.json +- Runtime machine config (local user setting, not tracked): data/config/machine.json +- Runtime story fallback config (local user setting, not tracked): data/config/story.json +- Runtime projects registry (local user setting, not tracked): data/config/projects.json +- Project-shipped model presets database (tracked): resources/config/model_presets.json - Environment variables always override JSON values. JSON may include placeholders like ${OPENAI_API_KEY}. -Sample files can be found under resources/config/examples/: +Sample files can be found under resources/config/examples/ (tracked and for inspiration only; the app does not auto-load them as runtime config): - resources/config/examples/machine.json - resources/config/examples/story.json diff --git a/resources/config/model_presets.json b/resources/config/model_presets.json new file mode 100644 index 0000000..22f3758 --- /dev/null +++ b/resources/config/model_presets.json @@ -0,0 +1,100 @@ +{ + "presets": [ + { + "id": "openai-balanced-chat", + "name": "Balanced Chat", + "description": "Balanced default for general chat and co-writing tasks.", + "model_id_patterns": ["^gpt-4(\\.|-|o)", "^gpt-5"], + "parameters": { + "temperature": 0.7, + "top_p": 0.95, + "max_tokens": 2048, + "presence_penalty": 0, + "frequency_penalty": 0, + "stop": [], + "seed": null, + "top_k": null, + "min_p": null, + "extra_body": "" + } + }, + { + "id": "reasoning-low-temp", + "name": "Reasoning (Low Temp)", + "description": "Conservative sampling for analysis/editing and deterministic behavior.", + "model_id_patterns": ["reason", "^o[13]", "^gpt-5"], + "parameters": { + "temperature": 0.2, + "top_p": 0.9, + "max_tokens": 4096, + "presence_penalty": 0, + "frequency_penalty": 0, + "stop": [], + "seed": 42, + "top_k": null, + "min_p": null, + "extra_body": "" + } + }, + { + "id": "thinking-heavy-chat", + "name": "Thinking Heavy", + "description": "High-latency reasoning-focused setup. Usually a poor default for writing flow.", + "model_id_patterns": ["thinking", "reasoning"], + "parameters": { + "temperature": 0.6, + "top_p": 0.95, + "max_tokens": 4096, + "presence_penalty": 0, + "frequency_penalty": 0, + "stop": [], + "seed": null, + "top_k": null, + "min_p": null, + "extra_body": "{\"reasoning\": {\"enabled\": true}}" + }, + "warnings": { + "writing": "This preset is optimized for deep reasoning and may slow writing flow or produce overlong internal reasoning output." + } + }, + { + "id": "qwen35-thinking-default", + "name": "Qwen 3.5 Thinking", + "description": "Qwen3.5 recommended sampling for default thinking mode.", + "model_id_patterns": ["^qwen(/|-)?.*3\\.5", "Qwen3\\.5", "Qwen/Qwen3\\.5"], + "parameters": { + "temperature": 0.6, + "top_p": 0.95, + "max_tokens": null, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "stop": [], + "seed": null, + "top_k": 20, + "min_p": 0.0, + "extra_body": "{\"repetition_penalty\": 1.0}" + }, + "warnings": { + "writing": "This is a thinking-mode preset and should not be used for WRITING; prefer the non-thinking preset for drafting flow." + } + }, + { + "id": "qwen35-instruct-non-thinking", + "name": "Qwen 3.5 Instruct (Non-Thinking)", + "description": "Qwen3.5 recommended non-thinking sampling profile with direct responses.", + "model_id_patterns": ["^qwen(/|-)?.*3\\.5", "Qwen3\\.5", "Qwen/Qwen3\\.5"], + "parameters": { + "temperature": 0.7, + "top_p": 0.8, + "max_tokens": null, + "presence_penalty": 1.5, + "frequency_penalty": 0.0, + "stop": [], + "seed": null, + "top_k": 20, + "min_p": 0.0, + "extra_body": "{\"repetition_penalty\": 1.0, \"chat_template_kwargs\": {\"enable_thinking\": false}}" + } + } + ] +} diff --git a/resources/schemas/README.md b/resources/schemas/README.md index 14032b2..3f66cea 100644 --- a/resources/schemas/README.md +++ b/resources/schemas/README.md @@ -5,8 +5,9 @@ This directory contains JSON Schema files for validating configuration and proje ## Files - `story-v2.schema.json`: Schema for `story.json` files in project directories, with metadata.version = 2. -- `projects.schema.json`: Schema for `resources/config/projects.json`. -- `machine.schema.json`: Schema for `resources/config/machine.json`. +- `projects.schema.json`: Schema for runtime `projects.json` settings. +- `machine.schema.json`: Schema for runtime `data/config/machine.json` settings. +- `model_presets.schema.json`: Schema for tracked preset DB `resources/config/model_presets.json`. ## Usage diff --git a/resources/schemas/machine.schema.json b/resources/schemas/machine.schema.json index 9b84cf4..77c9c2a 100644 --- a/resources/schemas/machine.schema.json +++ b/resources/schemas/machine.schema.json @@ -30,6 +30,57 @@ "model": { "type": "string" }, + "temperature": { + "type": ["number", "null"], + "minimum": 0, + "maximum": 2 + }, + "top_p": { + "type": ["number", "null"], + "minimum": 0, + "maximum": 1 + }, + "max_tokens": { + "type": ["integer", "null"], + "minimum": 1 + }, + "presence_penalty": { + "type": ["number", "null"], + "minimum": -2, + "maximum": 2 + }, + "frequency_penalty": { + "type": ["number", "null"], + "minimum": -2, + "maximum": 2 + }, + "stop": { + "type": ["array", "null"], + "items": { + "type": "string" + } + }, + "seed": { + "type": ["integer", "null"] + }, + "top_k": { + "type": ["integer", "null"], + "minimum": 1 + }, + "min_p": { + "type": ["number", "null"], + "minimum": 0, + "maximum": 1 + }, + "extra_body": { + "type": ["string", "null"] + }, + "preset_id": { + "type": ["string", "null"] + }, + "writing_warning": { + "type": ["string", "null"] + }, "is_multimodal": { "type": ["boolean", "null"] }, @@ -49,6 +100,15 @@ "selected": { "type": "string", "description": "The selected model name" + }, + "selected_chat": { + "type": "string" + }, + "selected_writing": { + "type": "string" + }, + "selected_editing": { + "type": "string" } }, "required": ["models", "selected"] diff --git a/resources/schemas/model_presets.schema.json b/resources/schemas/model_presets.schema.json new file mode 100644 index 0000000..e4bf50a --- /dev/null +++ b/resources/schemas/model_presets.schema.json @@ -0,0 +1,46 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://example.com/model_presets.schema.json", + "title": "Model Preset Database", + "type": "object", + "properties": { + "presets": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "name": { + "type": "string" + }, + "description": { + "type": "string" + }, + "model_id_patterns": { + "type": "array", + "items": { + "type": "string" + } + }, + "parameters": { + "type": "object", + "additionalProperties": true + }, + "warnings": { + "type": "object", + "properties": { + "writing": { + "type": "string" + } + }, + "additionalProperties": true + } + }, + "required": ["id", "name", "description", "model_id_patterns", "parameters"] + } + } + }, + "required": ["presets"] +} diff --git a/src/augmentedquill/api/v1/chat.py b/src/augmentedquill/api/v1/chat.py index e8cba3a..2552f9d 100644 --- a/src/augmentedquill/api/v1/chat.py +++ b/src/augmentedquill/api/v1/chat.py @@ -15,7 +15,11 @@ from fastapi import APIRouter, Request, HTTPException from fastapi.responses import JSONResponse, StreamingResponse -from augmentedquill.core.config import load_machine_config, CONFIG_DIR +from augmentedquill.core.config import ( + load_machine_config, + DEFAULT_MACHINE_CONFIG_PATH, + DEFAULT_STORY_CONFIG_PATH, +) from augmentedquill.api.v1.http_responses import error_json, ok_json from augmentedquill.services.projects.projects import get_active_project_dir from augmentedquill.services.llm.llm import add_llm_log, create_log_entry @@ -51,7 +55,7 @@ @router.get("/chat", response_model=ChatInitialStateResponse) async def api_get_chat() -> ChatInitialStateResponse: """Return initial state for chat view: models and current selection.""" - machine = load_machine_config(CONFIG_DIR / "machine.json") or {} + machine = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} openai_cfg = (machine.get("openai") or {}) if isinstance(machine, dict) else {} models_list = openai_cfg.get("models") if isinstance(openai_cfg, dict) else [] @@ -151,7 +155,7 @@ async def api_chat_stream(request: Request) -> StreamingResponse: "model_name": "name-of-configured-entry" | null, "model_type": "CHAT" | "WRITING" | "EDITING" | null, "messages": [{"role": "system|user|assistant", "content": str}, ...], - // optional overrides (otherwise pulled from resources/config/machine.json) + // optional overrides (otherwise pulled from runtime user machine config) "base_url": str, "api_key": str, "model": str, @@ -170,7 +174,7 @@ async def api_chat_stream(request: Request) -> StreamingResponse: raise HTTPException(status_code=400, detail="messages array is required") # Load config to determine model capabilities and overrides - machine = load_machine_config(CONFIG_DIR / "machine.json") or {} + machine = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} stream_ctx = resolve_stream_model_context(payload, machine) model_type = stream_ctx["model_type"] selected_name = stream_ctx["selected_name"] @@ -180,6 +184,7 @@ async def api_chat_stream(request: Request) -> StreamingResponse: timeout_s = stream_ctx["timeout_s"] is_multimodal = stream_ctx["is_multimodal"] supports_function_calling = stream_ctx["supports_function_calling"] + chosen = stream_ctx["chosen"] # Inject images if referenced in the last user message and supported if is_multimodal: @@ -204,18 +209,71 @@ async def api_chat_stream(request: Request) -> StreamingResponse: headers["Authorization"] = f"Bearer {api_key}" temperature, max_tokens = resolve_story_llm_prefs( - config_dir=CONFIG_DIR, + config_dir=DEFAULT_STORY_CONFIG_PATH.parent, active_project_dir=get_active_project_dir(), ) + def _to_float(value): + try: + if value is None or value == "": + return None + return float(value) + except Exception: + return None + + def _to_int(value): + try: + if value is None or value == "": + return None + return int(value) + except Exception: + return None + + model_temperature = _to_float((chosen or {}).get("temperature")) + if model_temperature is None: + model_temperature = temperature + + model_max_tokens = _to_int((chosen or {}).get("max_tokens")) + if model_max_tokens is None: + model_max_tokens = max_tokens + + extra_body: Dict[str, Any] = {} + for key in ( + "top_p", + "presence_penalty", + "frequency_penalty", + "seed", + "top_k", + "min_p", + ): + value = (chosen or {}).get(key) + if value is not None: + extra_body[key] = value + + stop = (chosen or {}).get("stop") + if isinstance(stop, list) and stop: + extra_body["stop"] = [str(entry) for entry in stop] + + raw_extra_body = (chosen or {}).get("extra_body") + if isinstance(raw_extra_body, str) and raw_extra_body.strip(): + try: + parsed_extra = _json.loads(raw_extra_body) + if isinstance(parsed_extra, dict): + extra_body.update(parsed_extra) + except Exception: + # Invalid JSON is ignored by design so users can save drafts safely. + pass + body: Dict[str, Any] = { "model": model_id, "messages": req_messages, - "temperature": temperature, + "temperature": model_temperature, "stream": True, } - if isinstance(max_tokens, int): - body["max_tokens"] = max_tokens + if isinstance(model_max_tokens, int): + body["max_tokens"] = model_max_tokens + if extra_body: + body.update(extra_body) # Pass through OpenAI tool-calling fields if provided tool_choice = None @@ -247,8 +305,9 @@ async def _gen(): supports_function_calling=supports_function_calling, tools=story_tools, tool_choice=tool_choice if tool_choice != "none" else None, - temperature=temperature, - max_tokens=max_tokens, + temperature=model_temperature, + max_tokens=model_max_tokens, + extra_body=extra_body, log_entry=log_entry, skip_validation=True, # Trust configured models ): diff --git a/src/augmentedquill/api/v1/settings.py b/src/augmentedquill/api/v1/settings.py index 75a02b5..2d93161 100644 --- a/src/augmentedquill/api/v1/settings.py +++ b/src/augmentedquill/api/v1/settings.py @@ -16,10 +16,13 @@ from augmentedquill.core.config import ( load_machine_config, + load_model_presets_config, save_story_config, CURRENT_SCHEMA_VERSION, BASE_DIR, - CONFIG_DIR, + DEFAULT_MACHINE_CONFIG_PATH, + DEFAULT_STORY_CONFIG_PATH, + DEFAULT_MODEL_PRESETS_PATH, ) from augmentedquill.services.projects.projects import get_active_project_dir from augmentedquill.core.prompts import ( @@ -79,8 +82,8 @@ async def api_settings_post(request: Request) -> JSONResponse: try: active = get_active_project_dir() - story_path = (active / "story.json") if active else (CONFIG_DIR / "story.json") - machine_path = CONFIG_DIR / "machine.json" + story_path = (active / "story.json") if active else DEFAULT_STORY_CONFIG_PATH + machine_path = DEFAULT_MACHINE_CONFIG_PATH story_path.parent.mkdir(parents=True, exist_ok=True) machine_path.parent.mkdir(parents=True, exist_ok=True) save_story_config(story_path, story_cfg) @@ -94,7 +97,7 @@ async def api_settings_post(request: Request) -> JSONResponse: @router.get("/prompts") async def api_prompts_get(model_name: str | None = None) -> JSONResponse: """Get all resolved prompts (defaults + global overrides + model overrides).""" - machine_config = load_machine_config(CONFIG_DIR / "machine.json") or {} + machine_config = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} if not model_name: model_name = machine_config.get("openai", {}).get("selected") @@ -154,6 +157,16 @@ async def api_machine_test(request: Request) -> JSONResponse: ) +@router.get("/machine/presets") +async def api_machine_presets_get() -> JSONResponse: + """Return model preset database used by Machine Settings UI.""" + data = load_model_presets_config(DEFAULT_MODEL_PRESETS_PATH) or {} + presets = data.get("presets") if isinstance(data, dict) else [] + if not isinstance(presets, list): + presets = [] + return JSONResponse(status_code=200, content={"presets": presets}) + + @router.post("/machine/test_model") async def api_machine_test_model(request: Request) -> JSONResponse: """Test whether a model is available for base_url + api_key. @@ -236,7 +249,7 @@ async def api_machine_test_model(request: Request) -> JSONResponse: @router.put("/machine") async def api_machine_put(request: Request) -> JSONResponse: - """Persist machine config to resources/config/machine.json. + """Persist machine config to runtime user config path. Body: { openai: { models: [{name, base_url, api_key?, timeout_s?, model}], selected? } } Returns: { ok: bool, detail?: str } @@ -255,7 +268,7 @@ async def api_machine_put(request: Request) -> JSONResponse: ) try: - machine_path = CONFIG_DIR / "machine.json" + machine_path = DEFAULT_MACHINE_CONFIG_PATH machine_path.parent.mkdir(parents=True, exist_ok=True) machine_path.write_text(_json.dumps(machine_cfg, indent=2), encoding="utf-8") except Exception as e: @@ -278,7 +291,7 @@ async def api_story_summary_put(request: Request) -> JSONResponse: summary = payload.get("summary", "") try: active = get_active_project_dir() - story_path = (active / "story.json") if active else (CONFIG_DIR / "story.json") + story_path = (active / "story.json") if active else DEFAULT_STORY_CONFIG_PATH update_story_field(story_path, "story_summary", summary) except Exception as e: return JSONResponse( @@ -303,7 +316,7 @@ async def api_story_tags_put(request: Request) -> JSONResponse: try: active = get_active_project_dir() - story_path = (active / "story.json") if active else (CONFIG_DIR / "story.json") + story_path = (active / "story.json") if active else DEFAULT_STORY_CONFIG_PATH update_story_field(story_path, "tags", tags) except Exception as e: return error_json(f"Failed to update story tags: {e}", status_code=500) @@ -316,10 +329,10 @@ async def update_story_config(request: Request): """Update the story config to the latest version.""" try: active = get_active_project_dir() - story_path = (active / "story.json") if active else (CONFIG_DIR / "story.json") + story_path = (active / "story.json") if active else DEFAULT_STORY_CONFIG_PATH ok, message = run_story_config_update( base_dir=BASE_DIR, - config_dir=CONFIG_DIR, + config_dir=DEFAULT_STORY_CONFIG_PATH.parent, story_path=story_path, current_schema_version=CURRENT_SCHEMA_VERSION, ) diff --git a/src/augmentedquill/api/v1/sourcebook.py b/src/augmentedquill/api/v1/sourcebook.py index 05e4716..172684f 100644 --- a/src/augmentedquill/api/v1/sourcebook.py +++ b/src/augmentedquill/api/v1/sourcebook.py @@ -54,7 +54,7 @@ class SourcebookEntryUpdate(BaseModel): async def get_sourcebook() -> List[SourcebookEntry]: active = get_active_project_dir() if not active: - raise HTTPException(status_code=400, detail="No active project") + return [] return [SourcebookEntry(**entry) for entry in sourcebook_list_entries()] diff --git a/src/augmentedquill/core/config.py b/src/augmentedquill/core/config.py index a971ae3..52b0d95 100644 --- a/src/augmentedquill/core/config.py +++ b/src/augmentedquill/core/config.py @@ -10,8 +10,8 @@ Configuration loading utilities for AugmentedQuill. Conventions: -- Machine-specific config: resources/config/machine.json -- Story-specific config: resources/config/story.json +- Runtime user config: data/config/{machine,story,projects}.json +- Project-shipped config assets: resources/config/*.json (e.g., model presets) - Environment variables override JSON values. - JSON values can reference environment variables using ${VAR_NAME} placeholders. @@ -35,14 +35,18 @@ CONFIG_DIR = BASE_DIR / "resources" / "config" SCHEMAS_DIR = BASE_DIR / "resources" / "schemas" RESOURCES_DIR = BASE_DIR / "resources" -DATA_DIR = BASE_DIR / "data" +_ENV_USER_DATA_DIR = os.getenv("AUGQ_USER_DATA_DIR") +DATA_DIR = Path(_ENV_USER_DATA_DIR) if _ENV_USER_DATA_DIR else BASE_DIR / "data" PROJECTS_ROOT = DATA_DIR / "projects" LOGS_DIR = DATA_DIR / "logs" STATIC_DIR = BASE_DIR / "static" CURRENT_SCHEMA_VERSION = 2 -DEFAULT_MACHINE_CONFIG_PATH = CONFIG_DIR / "machine.json" -DEFAULT_STORY_CONFIG_PATH = CONFIG_DIR / "story.json" +USER_CONFIG_DIR = DATA_DIR / "config" +DEFAULT_MACHINE_CONFIG_PATH = USER_CONFIG_DIR / "machine.json" +DEFAULT_STORY_CONFIG_PATH = USER_CONFIG_DIR / "story.json" +DEFAULT_PROJECTS_REGISTRY_PATH = USER_CONFIG_DIR / "projects.json" +DEFAULT_MODEL_PRESETS_PATH = CONFIG_DIR / "model_presets.json" def _get_story_schema(version: int) -> Dict[str, Any]: @@ -184,3 +188,47 @@ def save_story_config(path: os.PathLike[str] | str, config: Dict[str, Any]) -> N with p.open("w", encoding="utf-8") as f: json.dump(clean_config, f, indent=2, ensure_ascii=False) + + +def load_model_presets_config( + path: os.PathLike[str] | str | None = DEFAULT_MODEL_PRESETS_PATH, +) -> Dict[str, Any]: + """Load global model preset database JSON.""" + return load_json_file(path) + + +def ensure_runtime_user_config_files() -> None: + """Create missing runtime user config files with safe defaults. + + This keeps first startup usable without manual setup. + """ + USER_CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + machine_path = DEFAULT_MACHINE_CONFIG_PATH + if not machine_path.exists(): + machine_path.write_text("{}\n", encoding="utf-8") + + story_path = DEFAULT_STORY_CONFIG_PATH + if not story_path.exists(): + story_payload: Dict[str, Any] = { + "project_title": "Untitled Project", + "project_type": "novel", + "chapters": [], + "format": "markdown", + "metadata": {"version": CURRENT_SCHEMA_VERSION}, + "llm_prefs": {"temperature": 0.7, "max_tokens": 2048}, + "sourcebook": {}, + "story_summary": "", + "tags": [], + } + story_path.write_text( + json.dumps(story_payload, indent=2, ensure_ascii=False) + "\n", + encoding="utf-8", + ) + + projects_registry_path = DEFAULT_PROJECTS_REGISTRY_PATH + if not projects_registry_path.exists(): + projects_registry_path.write_text( + json.dumps({"current": "", "recent": []}, indent=2) + "\n", + encoding="utf-8", + ) diff --git a/src/augmentedquill/main.py b/src/augmentedquill/main.py index a875554..eee5be7 100644 --- a/src/augmentedquill/main.py +++ b/src/augmentedquill/main.py @@ -22,7 +22,12 @@ from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware -from augmentedquill.core.config import load_machine_config, STATIC_DIR, CONFIG_DIR +from augmentedquill.core.config import ( + load_machine_config, + STATIC_DIR, + DEFAULT_MACHINE_CONFIG_PATH, + ensure_runtime_user_config_files, +) from augmentedquill.services.exceptions import ServiceError # Import API routers @@ -43,6 +48,7 @@ def create_app() -> FastAPI: """ app = FastAPI(title="AugmentedQuill") + ensure_runtime_user_config_files() # Dynamic CORS origin handler to support variable ports async def get_origins(request: Request) -> list[str]: @@ -87,7 +93,7 @@ async def get_origins(request: Request) -> list[str]: ) api_v1_router.add_api_route( "/machine", - endpoint=lambda: load_machine_config(CONFIG_DIR / "machine.json") or {}, + endpoint=lambda: load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {}, methods=["GET"], ) diff --git a/src/augmentedquill/services/llm/llm.py b/src/augmentedquill/services/llm/llm.py index 4706239..0905777 100644 --- a/src/augmentedquill/services/llm/llm.py +++ b/src/augmentedquill/services/llm/llm.py @@ -20,7 +20,7 @@ import httpx -from augmentedquill.core.config import load_machine_config, CONFIG_DIR +from augmentedquill.core.config import load_machine_config, DEFAULT_MACHINE_CONFIG_PATH from augmentedquill.services.llm import llm_logging as _llm_logging from augmentedquill.services.llm import llm_stream_ops as _llm_stream_ops from augmentedquill.services.llm import llm_completion_ops as _llm_completion_ops @@ -39,7 +39,7 @@ def get_selected_model_name( payload: Dict[str, Any], model_type: str | None = None ) -> str | None: """Get the selected model name based on payload and model_type.""" - machine = load_machine_config(CONFIG_DIR / "machine.json") or {} + machine = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} openai_cfg: Dict[str, Any] = machine.get("openai") or {} selected_name = payload.get("model_name") @@ -67,7 +67,7 @@ def resolve_openai_credentials( 2. Payload overrides: base_url, api_key, model, timeout_s or model_name (by name) 3. machine.json -> openai.models[] (selected by name based on model_type) """ - machine = load_machine_config(CONFIG_DIR / "machine.json") or {} + machine = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} openai_cfg: Dict[str, Any] = machine.get("openai") or {} selected_name = get_selected_model_name(payload, model_type) @@ -123,6 +123,7 @@ async def unified_chat_stream( tool_choice: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, + extra_body: dict | None = None, log_entry: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[dict]: @@ -140,6 +141,7 @@ async def unified_chat_stream( tool_choice=tool_choice, temperature=temperature, max_tokens=max_tokens, + extra_body=extra_body, log_entry=log_entry, skip_validation=skip_validation, ): @@ -158,6 +160,7 @@ async def unified_chat_complete( tool_choice: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, + extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: """Unified Chat Complete.""" @@ -173,6 +176,7 @@ async def unified_chat_complete( tool_choice=tool_choice, temperature=temperature, max_tokens=max_tokens, + extra_body=extra_body, skip_validation=skip_validation, ) @@ -184,6 +188,8 @@ async def openai_chat_complete( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: @@ -195,6 +201,8 @@ async def openai_chat_complete( api_key=api_key, model_id=model_id, timeout_s=timeout_s, + temperature=temperature, + max_tokens=max_tokens, extra_body=extra_body, skip_validation=skip_validation, ) @@ -208,6 +216,8 @@ async def openai_completions( model_id: str, timeout_s: int, n: int = 1, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: @@ -220,6 +230,8 @@ async def openai_completions( model_id=model_id, timeout_s=timeout_s, n=n, + temperature=temperature, + max_tokens=max_tokens, extra_body=extra_body, skip_validation=skip_validation, ) @@ -232,6 +244,9 @@ async def openai_chat_complete_stream( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, + extra_body: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[str]: """Openai Chat Complete Stream.""" @@ -242,6 +257,9 @@ async def openai_chat_complete_stream( api_key=api_key, model_id=model_id, timeout_s=timeout_s, + temperature=temperature, + max_tokens=max_tokens, + extra_body=extra_body, skip_validation=skip_validation, ): yield chunk @@ -254,6 +272,8 @@ async def openai_completions_stream( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[str]: @@ -265,6 +285,8 @@ async def openai_completions_stream( api_key=api_key, model_id=model_id, timeout_s=timeout_s, + temperature=temperature, + max_tokens=max_tokens, extra_body=extra_body, skip_validation=skip_validation, ): diff --git a/src/augmentedquill/services/llm/llm_completion_ops.py b/src/augmentedquill/services/llm/llm_completion_ops.py index de48e4c..d123b0b 100644 --- a/src/augmentedquill/services/llm/llm_completion_ops.py +++ b/src/augmentedquill/services/llm/llm_completion_ops.py @@ -12,12 +12,14 @@ from typing import Any, Dict, AsyncIterator import datetime import os +import json import httpx from augmentedquill.core.config import ( load_story_config, - CONFIG_DIR, + DEFAULT_STORY_CONFIG_PATH, + DEFAULT_MACHINE_CONFIG_PATH, load_machine_config, ) from augmentedquill.services.projects.projects import get_active_project_dir @@ -61,8 +63,7 @@ def _validate_base_url(base_url: str, skip_validation: bool = False) -> None: return # 2. Check machine.json models - config_path = os.path.join(CONFIG_DIR, "machine.json") - machine_config = load_machine_config(config_path) + machine_config = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) if machine_config: for provider in ["openai", "anthropic", "google"]: all_models = machine_config.get(provider, {}).get("models", []) @@ -115,6 +116,102 @@ def _prepare_llm_request( return url, headers, body +def _resolve_temperature_max_tokens( + temperature: float | None, + max_tokens: int | None, + model_cfg: dict | None = None, +) -> tuple[float, int | None]: + """Resolve runtime temperature/max_tokens with story defaults fallback.""" + if temperature is not None and max_tokens is not None: + return float(temperature), int(max_tokens) + + model_temperature = None + model_max_tokens = None + if isinstance(model_cfg, dict): + try: + if model_cfg.get("temperature") not in (None, ""): + model_temperature = float(model_cfg.get("temperature")) + except Exception: + model_temperature = None + try: + if model_cfg.get("max_tokens") not in (None, ""): + model_max_tokens = int(model_cfg.get("max_tokens")) + except Exception: + model_max_tokens = None + + if temperature is None: + temperature = model_temperature + if max_tokens is None: + max_tokens = model_max_tokens + + if temperature is not None and max_tokens is not None: + return float(temperature), int(max_tokens) + + story_temperature, story_max_tokens = get_story_llm_preferences( + config_dir=DEFAULT_STORY_CONFIG_PATH.parent, + get_active_project_dir=get_active_project_dir, + load_story_config=load_story_config, + ) + return ( + float(temperature) if temperature is not None else story_temperature, + int(max_tokens) if max_tokens is not None else story_max_tokens, + ) + + +def _resolve_machine_model_cfg(base_url: str, model_id: str) -> dict: + """Resolve machine model entry matching base_url + model_id.""" + machine_config = load_machine_config(DEFAULT_MACHINE_CONFIG_PATH) or {} + openai_cfg = ( + machine_config.get("openai") if isinstance(machine_config, dict) else {} + ) + models = openai_cfg.get("models") if isinstance(openai_cfg, dict) else [] + if not isinstance(models, list): + return {} + for model in models: + if not isinstance(model, dict): + continue + if str(model.get("base_url") or "") != str(base_url or ""): + continue + if str(model.get("model") or "") != str(model_id or ""): + continue + return model + return {} + + +def _build_model_extra_body(model_cfg: dict) -> dict: + """Build extra_body payload from machine model parameters.""" + if not isinstance(model_cfg, dict): + return {} + + extra: dict = {} + for key in ( + "top_p", + "presence_penalty", + "frequency_penalty", + "seed", + "top_k", + "min_p", + ): + value = model_cfg.get(key) + if value is not None: + extra[key] = value + + stop = model_cfg.get("stop") + if isinstance(stop, list) and stop: + extra["stop"] = [str(entry) for entry in stop] + + raw_extra_body = model_cfg.get("extra_body") + if isinstance(raw_extra_body, str) and raw_extra_body.strip(): + try: + parsed = json.loads(raw_extra_body) + if isinstance(parsed, dict): + extra.update(parsed) + except Exception: + pass + + return extra + + async def unified_chat_complete( *, messages: list[dict], @@ -127,14 +224,15 @@ async def unified_chat_complete( tool_choice: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, + extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: """Execute a non-streaming chat completion and normalize tool/thinking output.""" - extra_body = {} + merged_extra_body = dict(extra_body or {}) if supports_function_calling and tools and tool_choice != "none": - extra_body["tools"] = tools + merged_extra_body["tools"] = tools if tool_choice: - extra_body["tool_choice"] = tool_choice + merged_extra_body["tool_choice"] = tool_choice resp_json = await openai_chat_complete( messages=messages, @@ -142,7 +240,9 @@ async def unified_chat_complete( api_key=api_key, model_id=model_id, timeout_s=timeout_s, - extra_body=extra_body, + temperature=temperature, + max_tokens=max_tokens, + extra_body=merged_extra_body, skip_validation=skip_validation, ) @@ -223,15 +323,16 @@ async def openai_chat_complete( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: """Call the OpenAI-compatible chat completions endpoint and return JSON.""" _validate_base_url(base_url, skip_validation=skip_validation) - temperature, max_tokens = get_story_llm_preferences( - config_dir=CONFIG_DIR, - get_active_project_dir=get_active_project_dir, - load_story_config=load_story_config, + model_cfg = _resolve_machine_model_cfg(base_url, model_id) + temperature, max_tokens = _resolve_temperature_max_tokens( + temperature, max_tokens, model_cfg ) url = str(base_url).rstrip("/") + "/chat/completions" @@ -244,6 +345,9 @@ async def openai_chat_complete( } if isinstance(max_tokens, int): body["max_tokens"] = max_tokens + model_extra = _build_model_extra_body(model_cfg) + if model_extra: + body.update(model_extra) if extra_body: body.update(extra_body) @@ -258,15 +362,16 @@ async def openai_completions( model_id: str, timeout_s: int, n: int = 1, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> dict: """Call the OpenAI-compatible text completions endpoint and return JSON.""" _validate_base_url(base_url, skip_validation=skip_validation) - temperature, max_tokens = get_story_llm_preferences( - config_dir=CONFIG_DIR, - get_active_project_dir=get_active_project_dir, - load_story_config=load_story_config, + model_cfg = _resolve_machine_model_cfg(base_url, model_id) + temperature, max_tokens = _resolve_temperature_max_tokens( + temperature, max_tokens, model_cfg ) url = str(base_url).rstrip("/") + "/completions" @@ -280,6 +385,9 @@ async def openai_completions( } if isinstance(max_tokens, int): body["max_tokens"] = max_tokens + model_extra = _build_model_extra_body(model_cfg) + if model_extra: + body.update(model_extra) if extra_body: body.update(extra_body) @@ -293,15 +401,17 @@ async def openai_chat_complete_stream( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, + extra_body: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[str]: """Stream content chunks from the chat completions endpoint.""" _validate_base_url(base_url, skip_validation=skip_validation) url = str(base_url).rstrip("/") + "/chat/completions" - temperature, max_tokens = get_story_llm_preferences( - config_dir=CONFIG_DIR, - get_active_project_dir=get_active_project_dir, - load_story_config=load_story_config, + model_cfg = _resolve_machine_model_cfg(base_url, model_id) + temperature, max_tokens = _resolve_temperature_max_tokens( + temperature, max_tokens, model_cfg ) headers = build_headers(api_key) @@ -314,6 +424,11 @@ async def openai_chat_complete_stream( } if isinstance(max_tokens, int): body["max_tokens"] = max_tokens + model_extra = _build_model_extra_body(model_cfg) + if model_extra: + body.update(model_extra) + if extra_body: + body.update(extra_body) # Security: Ensure sensitive headers (like Authorization) are masked BEFORE logging safe_log_headers = { @@ -373,16 +488,17 @@ async def openai_completions_stream( api_key: str | None, model_id: str, timeout_s: int, + temperature: float | None = None, + max_tokens: int | None = None, extra_body: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[str]: """Stream content chunks from the text completions endpoint.""" _validate_base_url(base_url, skip_validation=skip_validation) url = str(base_url).rstrip("/") + "/completions" - temperature, max_tokens = get_story_llm_preferences( - config_dir=CONFIG_DIR, - get_active_project_dir=get_active_project_dir, - load_story_config=load_story_config, + model_cfg = _resolve_machine_model_cfg(base_url, model_id) + temperature, max_tokens = _resolve_temperature_max_tokens( + temperature, max_tokens, model_cfg ) headers = build_headers(api_key) @@ -395,6 +511,9 @@ async def openai_completions_stream( } if isinstance(max_tokens, int): body["max_tokens"] = max_tokens + model_extra = _build_model_extra_body(model_cfg) + if model_extra: + body.update(model_extra) if extra_body: body.update(extra_body) diff --git a/src/augmentedquill/services/llm/llm_stream_ops.py b/src/augmentedquill/services/llm/llm_stream_ops.py index 07938bd..c12f33f 100644 --- a/src/augmentedquill/services/llm/llm_stream_ops.py +++ b/src/augmentedquill/services/llm/llm_stream_ops.py @@ -20,6 +20,7 @@ CONFIG_DIR, load_machine_config, ) +from augmentedquill.services.llm.llm_logging import add_llm_log from augmentedquill.utils.stream_helpers import ChannelFilter from augmentedquill.utils.llm_parsing import ( parse_complete_assistant_output, @@ -92,6 +93,7 @@ async def unified_chat_stream( tool_choice: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, + extra_body: dict | None = None, log_entry: dict | None = None, skip_validation: bool = False, ) -> AsyncIterator[dict]: @@ -110,6 +112,8 @@ async def unified_chat_stream( } if isinstance(max_tokens, int): body["max_tokens"] = max_tokens + if isinstance(extra_body, dict): + body.update(extra_body) if supports_function_calling and tools and tool_choice != "none": body["tools"] = tools @@ -189,6 +193,7 @@ async def unified_chat_stream( error_data = _json.loads(error_content) if log_entry: log_entry["response"]["error"] = error_data + add_llm_log(log_entry) yield { "error": "Upstream error", "status": resp.status_code, @@ -198,6 +203,7 @@ async def unified_chat_stream( err_text = error_content.decode("utf-8", errors="ignore") if log_entry: log_entry["response"]["error"] = err_text + add_llm_log(log_entry) yield { "error": "Upstream error", "status": resp.status_code, @@ -250,6 +256,15 @@ async def unified_chat_stream( yield {"done": True} except Exception as e: + if log_entry: + log_entry["timestamp_end"] = ( + datetime.datetime.now().isoformat() + ) + log_entry["response"]["error_detail"] = str(e) + log_entry["response"][ + "error" + ] = "Failed to parse non-stream response" + add_llm_log(log_entry) yield { "error": "Failed to parse response", "message": f"An error occurred while processing the response: {e}", @@ -290,10 +305,6 @@ async def unified_chat_stream( datetime.datetime.now().isoformat() ) # Force re-logging on completion so we get full_content and chunks - from augmentedquill.services.llm.llm_logging import ( - add_llm_log, - ) - add_llm_log(log_entry) yield {"done": True} @@ -341,10 +352,16 @@ async def unified_chat_stream( break except Exception as e: + err_text = str(e).strip() or f"{type(e).__name__}: {repr(e)}" if log_entry: - log_entry["response"]["error_detail"] = str(e) + log_entry["timestamp_end"] = datetime.datetime.now().isoformat() + log_entry["response"]["error_detail"] = err_text log_entry["response"][ "error" - ] = f"An internal error occurred during the LLM request: {e}" - yield {"error": "Connection error", "message": f"An error occurred: {e}."} + ] = f"An internal error occurred during the LLM request: {err_text}" + add_llm_log(log_entry) + yield { + "error": "Connection error", + "message": f"An error occurred: {err_text}.", + } break diff --git a/src/augmentedquill/services/projects/projects.py b/src/augmentedquill/services/projects/projects.py index a683eee..0c638f5 100644 --- a/src/augmentedquill/services/projects/projects.py +++ b/src/augmentedquill/services/projects/projects.py @@ -53,8 +53,8 @@ select_project_under_root, ) from augmentedquill.core.config import ( - CONFIG_DIR, PROJECTS_ROOT, + DEFAULT_PROJECTS_REGISTRY_PATH, ) @@ -75,19 +75,19 @@ class ProjectInfo: def load_registry() -> Dict: return load_registry_from_path( - Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(CONFIG_DIR / "projects.json"))) + Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(DEFAULT_PROJECTS_REGISTRY_PATH))) ) def set_active_project(path: Path) -> None: reg = load_registry() current, recent = set_active_project_in_registry( - Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(CONFIG_DIR / "projects.json"))), + Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(DEFAULT_PROJECTS_REGISTRY_PATH))), path, reg, ) save_registry_to_path( - Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(CONFIG_DIR / "projects.json"))), + Path(os.getenv("AUGQ_PROJECTS_REGISTRY", str(DEFAULT_PROJECTS_REGISTRY_PATH))), current, recent, ) @@ -117,7 +117,7 @@ def delete_project(name: str) -> Tuple[bool, str]: if ok: save_registry_to_path( Path( - os.getenv("AUGQ_PROJECTS_REGISTRY", str(CONFIG_DIR / "projects.json")) + os.getenv("AUGQ_PROJECTS_REGISTRY", str(DEFAULT_PROJECTS_REGISTRY_PATH)) ), current, recent, diff --git a/src/augmentedquill/services/settings/settings_api_ops.py b/src/augmentedquill/services/settings/settings_api_ops.py index 349b283..f0da4a5 100644 --- a/src/augmentedquill/services/settings/settings_api_ops.py +++ b/src/augmentedquill/services/settings/settings_api_ops.py @@ -10,6 +10,7 @@ from __future__ import annotations from pathlib import Path +import json from augmentedquill.core.config import load_story_config, save_story_config from augmentedquill.services.chapters.chapter_helpers import _normalize_chapter_entry @@ -86,6 +87,19 @@ def clean_machine_openai_cfg_for_put( selected = ( (openai_cfg.get("selected") or "") if isinstance(openai_cfg, dict) else "" ) + selected_chat = ( + (openai_cfg.get("selected_chat") or "") if isinstance(openai_cfg, dict) else "" + ) + selected_writing = ( + (openai_cfg.get("selected_writing") or "") + if isinstance(openai_cfg, dict) + else "" + ) + selected_editing = ( + (openai_cfg.get("selected_editing") or "") + if isinstance(openai_cfg, dict) + else "" + ) if not (isinstance(models, list) and models): return None, None, "At least one model must be configured in openai.models[]." @@ -115,6 +129,43 @@ def clean_machine_openai_cfg_for_put( except Exception: timeout_s_int = 60 + def _to_optional_float(value): + if value in (None, ""): + return None + try: + return float(value) + except Exception: + return None + + def _to_optional_int(value): + if value in (None, ""): + return None + try: + return int(value) + except Exception: + return None + + stop_value = model.get("stop") + if isinstance(stop_value, list): + stop_clean = [str(entry) for entry in stop_value if str(entry).strip()] + elif isinstance(stop_value, str): + stop_clean = [ + entry.strip() for entry in stop_value.split("\n") if entry.strip() + ] + else: + stop_clean = [] + + extra_body_value = model.get("extra_body") + if extra_body_value is None: + extra_body_clean = "" + elif isinstance(extra_body_value, str): + extra_body_clean = extra_body_value + else: + try: + extra_body_clean = json.dumps(extra_body_value) + except Exception: + extra_body_clean = "" + cleaned_models.append( { "name": name, @@ -122,6 +173,18 @@ def clean_machine_openai_cfg_for_put( "api_key": api_key, "timeout_s": timeout_s_int, "model": model_id, + "temperature": _to_optional_float(model.get("temperature")), + "top_p": _to_optional_float(model.get("top_p")), + "max_tokens": _to_optional_int(model.get("max_tokens")), + "presence_penalty": _to_optional_float(model.get("presence_penalty")), + "frequency_penalty": _to_optional_float(model.get("frequency_penalty")), + "stop": stop_clean, + "seed": _to_optional_int(model.get("seed")), + "top_k": _to_optional_int(model.get("top_k")), + "min_p": _to_optional_float(model.get("min_p")), + "extra_body": extra_body_clean, + "preset_id": (model.get("preset_id") or None), + "writing_warning": (model.get("writing_warning") or None), "is_multimodal": model.get("is_multimodal"), "supports_function_calling": model.get("supports_function_calling"), "prompt_overrides": prompt_overrides, @@ -141,7 +204,28 @@ def clean_machine_openai_cfg_for_put( elif selected not in [model.get("name") for model in cleaned_models]: selected = cleaned_models[0].get("name", "") - return {"openai": {"models": cleaned_models, "selected": selected}}, selected, None + available_names = [model.get("name") for model in cleaned_models] + + if not selected_chat or selected_chat not in available_names: + selected_chat = selected + if not selected_writing or selected_writing not in available_names: + selected_writing = selected + if not selected_editing or selected_editing not in available_names: + selected_editing = selected + + return ( + { + "openai": { + "models": cleaned_models, + "selected": selected, + "selected_chat": selected_chat, + "selected_writing": selected_writing, + "selected_editing": selected_editing, + } + }, + selected, + None, + ) def update_story_field(story_path: Path, field: str, value) -> None: diff --git a/src/augmentedquill/utils/llm_utils.py b/src/augmentedquill/utils/llm_utils.py index 3434d36..526373a 100644 --- a/src/augmentedquill/utils/llm_utils.py +++ b/src/augmentedquill/utils/llm_utils.py @@ -10,19 +10,37 @@ Common LLM-related utility functions, including capability verification and URL normalization. """ -import httpx import asyncio +import hashlib +import time + +import httpx # 1x1 transparent pixel PIXEL_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" +_CAPABILITY_CACHE_TTL_S = 3600 +_capability_cache: dict[tuple[str, str, str], tuple[float, dict]] = {} +_capability_inflight: dict[tuple[str, str, str], asyncio.Task] = {} +_capability_lock = asyncio.Lock() -async def verify_model_capabilities( - base_url: str, api_key: str | None, model_id: str, timeout_s: int = 10 + +def _cache_key( + base_url: str, api_key: str | None, model_id: str +) -> tuple[str, str, str]: + """Build stable cache key for provider/model/account scope.""" + normalized_base_url = str(base_url or "").strip().rstrip("/").lower() + normalized_model_id = str(model_id or "").strip() + api_key_hash = ( + hashlib.sha256(api_key.encode("utf-8")).hexdigest() if api_key else "" + ) + return normalized_base_url, normalized_model_id, api_key_hash + + +async def _probe_model_capabilities( + base_url: str, api_key: str | None, model_id: str, timeout_s: int ) -> dict: - """ - Dynamically tests the model for Vision and Function Calling capabilities by sending minimal requests. - """ + """Execute remote capability probes against chat-completions endpoint.""" url = str(base_url or "").strip().rstrip("/") + "/chat/completions" headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} headers["Content-Type"] = "application/json" @@ -49,7 +67,6 @@ async def check_vision(client): "max_tokens": 1, } response = await client.post(url, json=payload, headers=headers) - # If 200 OK, vision is supported. return response.status_code == 200 except Exception: return False @@ -74,18 +91,19 @@ async def check_function_calling(client): "max_tokens": 1, } response = await client.post(url, json=payload, headers=headers) - - # If 200 OK, we assume the API handled the 'tools' parameter gracefully (supported) - # If 400, it usually means 'tools' was not recognized. return response.status_code == 200 except Exception: return False - async with httpx.AsyncClient(timeout=timeout_s) as client: - # Run tests in parallel - results = await asyncio.gather( - check_vision(client), check_function_calling(client), return_exceptions=True - ) + try: + async with httpx.AsyncClient(timeout=timeout_s) as client: + results = await asyncio.gather( + check_vision(client), + check_function_calling(client), + return_exceptions=True, + ) + except Exception: + results = [False, False] is_multimodal = results[0] if isinstance(results[0], bool) else False supports_function_calling = results[1] if isinstance(results[1], bool) else False @@ -94,3 +112,53 @@ async def check_function_calling(client): "is_multimodal": is_multimodal, "supports_function_calling": supports_function_calling, } + + +def _clear_model_capabilities_cache_for_tests() -> None: + """Reset in-memory capability cache/inflight registry (test-only helper).""" + _capability_cache.clear() + _capability_inflight.clear() + + +async def verify_model_capabilities( + base_url: str, + api_key: str | None, + model_id: str, + timeout_s: int = 10, + cache_ttl_s: int = _CAPABILITY_CACHE_TTL_S, +) -> dict: + """ + Dynamically tests the model for Vision and Function Calling capabilities by sending minimal requests. + """ + key = _cache_key(base_url=base_url, api_key=api_key, model_id=model_id) + now = time.monotonic() + + async with _capability_lock: + cache_entry = _capability_cache.get(key) + if cache_entry and cache_entry[0] > now: + return cache_entry[1] + + inflight_task = _capability_inflight.get(key) + if inflight_task is None: + inflight_task = asyncio.create_task( + _probe_model_capabilities( + base_url=base_url, + api_key=api_key, + model_id=model_id, + timeout_s=timeout_s, + ) + ) + _capability_inflight[key] = inflight_task + + try: + capabilities = await inflight_task + if cache_ttl_s > 0: + expires_at = time.monotonic() + cache_ttl_s + async with _capability_lock: + _capability_cache[key] = (expires_at, capabilities) + return capabilities + finally: + async with _capability_lock: + current_task = _capability_inflight.get(key) + if current_task is inflight_task: + _capability_inflight.pop(key, None) diff --git a/src/frontend/App.tsx b/src/frontend/App.tsx index e307c24..2611512 100644 --- a/src/frontend/App.tsx +++ b/src/frontend/App.tsx @@ -39,6 +39,8 @@ import { DEFAULT_APP_SETTINGS } from './features/app/appDefaults'; import { getErrorMessage, resolveActiveProviderConfigs, + resolveRoleAvailability, + supportsImageActions, } from './features/app/appSelectors'; const App: React.FC = () => { @@ -74,6 +76,13 @@ const App: React.FC = () => { const { modelConnectionStatus, detectedCapabilities } = useProviderHealth(appSettings); + const roleAvailability = resolveRoleAvailability(appSettings, modelConnectionStatus); + const imageActionsAvailable = supportsImageActions( + appSettings, + detectedCapabilities, + modelConnectionStatus + ); + const [toolCallLoopDialog, setToolCallLoopDialog] = useState<{ count: number; resolver: (choice: 'stop' | 'continue' | 'unlimited') => void; @@ -186,6 +195,8 @@ const App: React.FC = () => { systemPrompt, activeEditingConfig, activeWritingConfig, + isEditingAvailable: roleAvailability.editing, + isWritingAvailable: roleAvailability.writing, updateChapter, setChatMessages, getErrorMessage, @@ -209,6 +220,7 @@ const App: React.FC = () => { story, systemPrompt, activeWritingConfig, + isWritingAvailable: roleAvailability.writing, updateChapter, viewMode, setChatMessages, @@ -255,6 +267,7 @@ const App: React.FC = () => { const { handleSendMessage, handleStopChat, handleRegenerate } = useChatExecution({ systemPrompt, activeChatConfig, + isChatAvailable: roleAvailability.chat, allowWebSearch, currentChapterId, chatMessages, @@ -299,6 +312,7 @@ const App: React.FC = () => { isImagesOpen={isImagesOpen} setIsImagesOpen={setIsImagesOpen} updateStoryImageSettings={updateStoryImageSettings} + imageActionsAvailable={imageActionsAvailable} editorRef={editorRef} isCreateProjectOpen={isCreateProjectOpen} setIsCreateProjectOpen={setIsCreateProjectOpen} @@ -330,7 +344,11 @@ const App: React.FC = () => { isMobileFormatMenuOpen, setIsMobileFormatMenuOpen, }} - aiControls={{ handleAiAction, isAiActionLoading }} + aiControls={{ + handleAiAction, + isAiActionLoading, + isWritingAvailable: roleAvailability.writing, + }} modelControls={{ appSettings, setAppSettings, @@ -364,6 +382,7 @@ const App: React.FC = () => { handleReorderChapters, handleReorderBooks, handleSidebarAiAction, + isEditingAvailable: roleAvailability.editing, handleOpenImages, updateStoryMetadata, }} @@ -384,6 +403,7 @@ const App: React.FC = () => { aiControls: { handleAiAction, isAiActionLoading, + isWritingAvailable: roleAvailability.writing, }, setActiveFormats, showWhitespace, @@ -393,6 +413,7 @@ const App: React.FC = () => { isChatOpen, chatMessages, isChatLoading, + isChatAvailable: roleAvailability.chat, systemPrompt, handleSendMessage, handleStopChat, diff --git a/src/frontend/features/app/appSelectors.test.ts b/src/frontend/features/app/appSelectors.test.ts index d70da72..b401245 100644 --- a/src/frontend/features/app/appSelectors.test.ts +++ b/src/frontend/features/app/appSelectors.test.ts @@ -11,7 +11,12 @@ import { describe, expect, it } from 'vitest'; import { AppSettings } from '../../types'; -import { getErrorMessage, resolveActiveProviderConfigs } from './appSelectors'; +import { + getErrorMessage, + resolveActiveProviderConfigs, + resolveRoleAvailability, + supportsImageActions, +} from './appSelectors'; const appSettings: AppSettings = { providers: [ @@ -64,4 +69,57 @@ describe('appSelectors', () => { it('uses fallback for non-error values', () => { expect(getErrorMessage({ bad: true }, 'fallback')).toBe('fallback'); }); + + it('computes role availability from model connection status', () => { + const availability = resolveRoleAvailability(appSettings, { + a: 'success', + b: 'error', + }); + expect(availability.chat).toBe(true); + expect(availability.writing).toBe(false); + expect(availability.editing).toBe(false); + }); + + it('requires multimodal support for image actions', () => { + const settingsWithVision: AppSettings = { + ...appSettings, + activeChatProviderId: 'a', + providers: [ + { + ...appSettings.providers[0], + isMultimodal: true, + }, + appSettings.providers[1], + ], + }; + + expect( + supportsImageActions( + settingsWithVision, + { + a: { is_multimodal: true, supports_function_calling: true }, + }, + { a: 'success' } + ) + ).toBe(true); + + expect( + supportsImageActions( + { + ...settingsWithVision, + providers: [ + { + ...settingsWithVision.providers[0], + isMultimodal: false, + }, + settingsWithVision.providers[1], + ], + }, + { + a: { is_multimodal: false, supports_function_calling: true }, + }, + { a: 'success' } + ) + ).toBe(false); + }); }); diff --git a/src/frontend/features/app/appSelectors.ts b/src/frontend/features/app/appSelectors.ts index 6858cfa..ef7c1de 100644 --- a/src/frontend/features/app/appSelectors.ts +++ b/src/frontend/features/app/appSelectors.ts @@ -11,6 +11,12 @@ import { AppSettings, LLMConfig } from '../../types'; +type ConnectionStatus = 'idle' | 'success' | 'error' | 'loading'; +type ProviderCapabilities = { + is_multimodal: boolean; + supports_function_calling: boolean; +}; + export function getErrorMessage(error: unknown, fallback: string): string { return error instanceof Error ? error.message : fallback; } @@ -33,3 +39,45 @@ export function resolveActiveProviderConfigs(appSettings: AppSettings): { fallback, }; } + +export function resolveRoleAvailability( + appSettings: AppSettings, + modelConnectionStatus: Record +): { + writing: boolean; + editing: boolean; + chat: boolean; +} { + const byId = new Map( + appSettings.providers.map((provider) => [provider.id, provider]) + ); + const isAvailable = (providerId: string) => { + const provider = byId.get(providerId); + if (!provider) return false; + if (!(provider.modelId || '').trim()) return false; + return modelConnectionStatus[provider.id] === 'success'; + }; + + return { + writing: isAvailable(appSettings.activeWritingProviderId), + editing: isAvailable(appSettings.activeEditingProviderId), + chat: isAvailable(appSettings.activeChatProviderId), + }; +} + +export function supportsImageActions( + appSettings: AppSettings, + detectedCapabilities: Record, + modelConnectionStatus: Record +): boolean { + const activeChatProvider = appSettings.providers.find( + (provider) => provider.id === appSettings.activeChatProviderId + ); + if (!activeChatProvider) return false; + if (modelConnectionStatus[activeChatProvider.id] !== 'success') return false; + + if (activeChatProvider.isMultimodal === true) return true; + if (activeChatProvider.isMultimodal === false) return false; + + return !!detectedCapabilities[activeChatProvider.id]?.is_multimodal; +} diff --git a/src/frontend/features/chapters/ChapterList.tsx b/src/frontend/features/chapters/ChapterList.tsx index 5c39a74..670deb6 100644 --- a/src/frontend/features/chapters/ChapterList.tsx +++ b/src/frontend/features/chapters/ChapterList.tsx @@ -43,6 +43,7 @@ interface ChapterListProps { action: 'write' | 'update' | 'rewrite', onProgress?: (text: string) => void ) => Promise; + isAiAvailable?: boolean; theme?: AppTheme; onOpenImages?: () => void; } @@ -62,6 +63,7 @@ export const ChapterList: React.FC = ({ onReorderChapters, onReorderBooks, onAiAction, + isAiAvailable = true, theme = 'mixed', onOpenImages, }) => { @@ -404,6 +406,11 @@ export const ChapterList: React.FC = ({ onSave={saveMetadata} onClose={() => setEditingMetadata(null)} theme={theme} + aiDisabledReason={ + !isAiAvailable + ? 'Summary AI is unavailable because no working EDITING model is configured.' + : undefined + } onAiGenerate={ onAiAction && editingMetadata ? (action, onProgress) => diff --git a/src/frontend/features/chapters/useChapterSuggestions.ts b/src/frontend/features/chapters/useChapterSuggestions.ts index 7d45067..84a74d8 100644 --- a/src/frontend/features/chapters/useChapterSuggestions.ts +++ b/src/frontend/features/chapters/useChapterSuggestions.ts @@ -22,6 +22,7 @@ type UseChapterSuggestionsParams = { story: StoryState; systemPrompt: string; activeWritingConfig: LLMConfig; + isWritingAvailable: boolean; updateChapter: (id: string, partial: Partial) => Promise; viewMode: ViewMode; setChatMessages: Dispatch>; @@ -34,6 +35,7 @@ export function useChapterSuggestions({ story, systemPrompt, activeWritingConfig, + isWritingAvailable, updateChapter, viewMode, setChatMessages, @@ -58,6 +60,7 @@ export function useChapterSuggestions({ enableSuggestionMode: boolean = true ) => { if (!currentChapter) return; + if (!isWritingAvailable) return; if (isSuggesting) return; const baseContent = contentOverride ?? currentChapter.content; @@ -75,7 +78,18 @@ export function useChapterSuggestions({ storyContext, systemPrompt, activeWritingConfig, - currentChapter.id + currentChapter.id, + { + onSuggestionUpdate: (index, text) => { + if (!text) return; + setContinuations((previous) => { + const next = [...previous]; + if (next[index] === text) return previous; + next[index] = text; + return next; + }); + }, + } ); setContinuations(options); } catch (error: unknown) { diff --git a/src/frontend/features/chat/Chat.tsx b/src/frontend/features/chat/Chat.tsx index 5eeb4c4..76ba58d 100644 --- a/src/frontend/features/chat/Chat.tsx +++ b/src/frontend/features/chat/Chat.tsx @@ -34,6 +34,7 @@ import { ChatComposer } from './components/ChatComposer'; interface ChatProps { messages: ChatMessage[]; isLoading: boolean; + isModelAvailable?: boolean; systemPrompt: string; onSendMessage: (text: string) => void; onStop?: () => void; @@ -58,6 +59,7 @@ interface ChatProps { export const Chat: React.FC = ({ messages, isLoading, + isModelAvailable = true, systemPrompt, onSendMessage, onStop, @@ -78,6 +80,9 @@ export const Chat: React.FC = ({ allowWebSearch, onToggleWebSearch, }) => { + const chatDisabledReason = + 'Chat is unavailable because no working CHAT model is configured.'; + const [input, setInput] = useState(''); const [editingMessageId, setEditingMessageId] = useState(null); const [editContent, setEditContent] = useState(''); @@ -146,7 +151,7 @@ export const Chat: React.FC = ({ const handleSubmit = (e?: React.FormEvent) => { e?.preventDefault(); - if (input.trim() && !isLoading) { + if (input.trim() && !isLoading && isModelAvailable) { onSendMessage(input.trim()); setInput(''); if (textareaRef.current) { @@ -182,7 +187,7 @@ export const Chat: React.FC = ({ }; const lastMessage = messages[messages.length - 1]; - const canRegenerate = !isLoading && lastMessage?.role === 'model'; + const canRegenerate = !isLoading && isModelAvailable && lastMessage?.role === 'model'; return (
= ({ headerBg={headerBg} currentSessionId={currentSessionId} isIncognito={isIncognito} + isDisabled={!isModelAvailable} + disabledReason={chatDisabledReason} showHistory={showHistory} setShowHistory={setShowHistory} showSystemPrompt={showSystemPrompt} @@ -207,6 +214,8 @@ export const Chat: React.FC = ({ = ({ onChange={(e) => setTempSystemPrompt(e.target.value)} className={`w-full h-32 rounded-md p-3 text-sm focus:ring-1 focus:ring-brand-500 focus:outline-none resize-none mb-3 border ${inputBg}`} placeholder="Define the AI's persona and rules..." + disabled={!isModelAvailable} + title={!isModelAvailable ? chatDisabledReason : 'System Instruction'} />
@@ -241,6 +254,8 @@ export const Chat: React.FC = ({ size="sm" variant="primary" onClick={handleSystemPromptSave} + disabled={!isModelAvailable} + title={!isModelAvailable ? chatDisabledReason : 'Update Persona'} > Update Persona @@ -255,6 +270,18 @@ export const Chat: React.FC = ({ isLight ? 'bg-brand-gray-50' : 'bg-brand-gray-950/30' }`} > + {!isModelAvailable && ( +
+ {chatDisabledReason} +
+ )} + {messages.length === 0 && !showSystemPrompt && (
@@ -353,6 +380,7 @@ export const Chat: React.FC = ({ size="sm" variant="secondary" onClick={() => { + if (!isModelAvailable) return; // Extract project name from either raw text or JSON message field let projectName = ''; try { @@ -377,6 +405,12 @@ export const Chat: React.FC = ({ } }} icon={} + disabled={!isModelAvailable} + title={ + !isModelAvailable + ? chatDisabledReason + : 'Switch to New Project' + } > Switch to New Project @@ -447,16 +481,24 @@ export const Chat: React.FC = ({ } opacity-0 group-hover:opacity-100 transition-opacity flex flex-col space-y-1`} > @@ -512,9 +554,14 @@ export const Chat: React.FC = ({ size="sm" variant="secondary" onClick={onRegenerate} + disabled={!isModelAvailable} icon={} className="text-xs py-1 h-7 border-dashed" - title="Regenerate last response (CHAT model)" + title={ + !isModelAvailable + ? chatDisabledReason + : 'Regenerate last response (CHAT model)' + } > Regenerate last response @@ -527,6 +574,8 @@ export const Chat: React.FC = ({ input={input} setInput={setInput} isLoading={isLoading} + isModelAvailable={isModelAvailable} + disabledReason={chatDisabledReason} inputBg={inputBg} onSubmit={handleSubmit} /> diff --git a/src/frontend/features/chat/ModelSelector.tsx b/src/frontend/features/chat/ModelSelector.tsx index 392810c..cb59ea4 100644 --- a/src/frontend/features/chat/ModelSelector.tsx +++ b/src/frontend/features/chat/ModelSelector.tsx @@ -10,7 +10,7 @@ */ import React, { useState, useRef, useEffect } from 'react'; -import { Eye, Wand2, ChevronDown, Check, AlertCircle, Loader2 } from 'lucide-react'; +import { Eye, Wand2, AlertTriangle, Loader2 } from 'lucide-react'; import { LLMConfig, AppTheme } from '../../types'; interface ModelSelectorProps { @@ -116,6 +116,13 @@ export const ModelSelector: React.FC = ({ 'supportsFunctionCalling', 'supports_function_calling' ) && } + {label === 'Writing' && activeOption?.writingWarning && ( + + )}
{getStatusIcon(opt.id)}
diff --git a/src/frontend/features/chat/components/ChatComposer.tsx b/src/frontend/features/chat/components/ChatComposer.tsx index f093509..f9bf6d4 100644 --- a/src/frontend/features/chat/components/ChatComposer.tsx +++ b/src/frontend/features/chat/components/ChatComposer.tsx @@ -17,6 +17,8 @@ type ChatComposerProps = { input: string; setInput: (value: string) => void; isLoading: boolean; + isModelAvailable?: boolean; + disabledReason?: string; inputBg: string; onSubmit: (e?: React.FormEvent) => void; }; @@ -26,9 +28,17 @@ export const ChatComposer: React.FC = ({ input, setInput, isLoading, + isModelAvailable = true, + disabledReason, inputBg, onSubmit, }) => { + const isDisabled = isLoading || !isModelAvailable; + const disabledTitle = !isModelAvailable + ? disabledReason || + 'Chat is unavailable because no working CHAT model is configured.' + : 'Send Message (CHAT model)'; + const handleKeyDown = (e: React.KeyboardEvent) => { if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); @@ -45,14 +55,15 @@ export const ChatComposer: React.FC = ({ onChange={(e) => setInput(e.target.value)} onKeyDown={handleKeyDown} placeholder="Type your instruction..." - className={`w-full pl-4 pr-12 py-3 rounded-lg focus:ring-2 focus:ring-brand-500 focus:border-brand-500 transition-all text-sm placeholder-brand-gray-400 border resize-none overflow-y-auto ${inputBg}`} - disabled={isLoading} + className={`w-full pl-4 pr-12 py-3 rounded-lg focus:ring-2 focus:ring-brand-500 focus:border-brand-500 transition-all text-sm placeholder-brand-gray-400 border resize-none overflow-y-auto disabled:cursor-not-allowed ${inputBg}`} + disabled={isDisabled} + title={disabledTitle} /> diff --git a/src/frontend/features/chat/components/ChatHeader.tsx b/src/frontend/features/chat/components/ChatHeader.tsx index 8715aea..4e8f62d 100644 --- a/src/frontend/features/chat/components/ChatHeader.tsx +++ b/src/frontend/features/chat/components/ChatHeader.tsx @@ -17,6 +17,8 @@ type ChatHeaderProps = { headerBg?: string; currentSessionId: string | null; isIncognito: boolean; + isDisabled?: boolean; + disabledReason?: string; showHistory: boolean; setShowHistory: (value: boolean) => void; showSystemPrompt: boolean; @@ -32,6 +34,8 @@ export const ChatHeader: React.FC = ({ headerBg = 'bg-brand-gray-100 dark:bg-brand-gray-900', currentSessionId, isIncognito, + isDisabled = false, + disabledReason, showHistory, setShowHistory, showSystemPrompt, @@ -41,6 +45,10 @@ export const ChatHeader: React.FC = ({ onNewSession, onToggleWebSearch, }) => { + const reason = + disabledReason || + 'Chat is unavailable because no working CHAT model is configured.'; + return (
= ({
diff --git a/src/frontend/features/chat/components/ChatHistoryPanel.tsx b/src/frontend/features/chat/components/ChatHistoryPanel.tsx index 7373c32..c2b8d1a 100644 --- a/src/frontend/features/chat/components/ChatHistoryPanel.tsx +++ b/src/frontend/features/chat/components/ChatHistoryPanel.tsx @@ -16,6 +16,8 @@ import { ChatSession } from '../../../types'; type ChatHistoryPanelProps = { sessions: ChatSession[]; currentSessionId: string | null; + isDisabled?: boolean; + disabledReason?: string; onSelectSession: (id: string) => void; onDeleteSession: (id: string) => void; onDeleteAllSessions?: () => void; @@ -25,11 +27,17 @@ type ChatHistoryPanelProps = { export const ChatHistoryPanel: React.FC = ({ sessions, currentSessionId, + isDisabled = false, + disabledReason, onSelectSession, onDeleteSession, onDeleteAllSessions, onClose, }) => { + const reason = + disabledReason || + 'Chat is unavailable because no working CHAT model is configured.'; + return (
@@ -39,8 +47,13 @@ export const ChatHistoryPanel: React.FC = ({ {sessions.length > 0 && onDeleteAllSessions && ( @@ -68,11 +81,13 @@ export const ChatHistoryPanel: React.FC = ({ currentSessionId === session.id ? 'bg-brand-gray-200 dark:bg-brand-gray-800 text-brand-600 font-medium' : 'hover:bg-brand-gray-200/50 dark:hover:bg-brand-gray-800/50' - }`} + } ${isDisabled ? 'opacity-60 cursor-not-allowed' : ''}`} onClick={() => { + if (isDisabled) return; onSelectSession(session.id); onClose(); }} + title={isDisabled ? reason : session.name} >
@@ -92,11 +107,14 @@ export const ChatHistoryPanel: React.FC = ({ diff --git a/src/frontend/features/chat/useChatExecution.ts b/src/frontend/features/chat/useChatExecution.ts index e456575..8395ca1 100644 --- a/src/frontend/features/chat/useChatExecution.ts +++ b/src/frontend/features/chat/useChatExecution.ts @@ -33,6 +33,7 @@ type ToolLoopChoice = 'stop' | 'continue' | 'unlimited'; type UseChatExecutionParams = { systemPrompt: string; activeChatConfig: LLMConfig; + isChatAvailable: boolean; allowWebSearch: boolean; currentChapterId: string | null; chatMessages: ChatMessage[]; @@ -47,6 +48,7 @@ type UseChatExecutionParams = { export function useChatExecution({ systemPrompt, activeChatConfig, + isChatAvailable, allowWebSearch, currentChapterId, chatMessages, @@ -241,6 +243,7 @@ export function useChatExecution({ }; const handleSendMessage = async (text: string) => { + if (!isChatAvailable) return; const userMsgId = uuidv4(); const newMessage: ChatMessage = { id: userMsgId, role: 'user', text }; const historyBefore = [...chatMessages]; @@ -254,6 +257,7 @@ export function useChatExecution({ }; const handleRegenerate = async () => { + if (!isChatAvailable) return; const lastMessageIndex = chatMessages.length - 1; if (lastMessageIndex < 0) return; diff --git a/src/frontend/features/editor/Editor.tsx b/src/frontend/features/editor/Editor.tsx index 5fde093..7d43872 100644 --- a/src/frontend/features/editor/Editor.tsx +++ b/src/frontend/features/editor/Editor.tsx @@ -64,6 +64,7 @@ interface EditorProps { action: 'update' | 'rewrite' | 'extend' ) => void; isAiLoading: boolean; + isWritingAvailable?: boolean; }; onContextChange?: (formats: string[]) => void; } @@ -97,6 +98,7 @@ export const Editor = React.forwardRef( const textareaRef = useRef(null); const wysiwygRef = useRef(null); const fileInputRef = useRef(null); + const scrollContainerRef = useRef(null); const { continuations, @@ -106,7 +108,9 @@ export const Editor = React.forwardRef( isSuggestionMode, onKeyboardSuggestionAction, } = suggestionControls; - const { onAiAction, isAiLoading } = aiControls; + const { onAiAction, isAiLoading, isWritingAvailable = true } = aiControls; + const writingUnavailableReason = + 'This action is unavailable because no working WRITING model is configured.'; const [isDragging, setIsDragging] = useState(false); @@ -498,6 +502,31 @@ export const Editor = React.forwardRef( settings.theme === 'light' ? 'bg-brand-gray-50 border-t border-brand-gray-200' : 'bg-brand-gray-900 border-t border-brand-gray-800'; + const hasContinuationOptions = continuations.some( + (option) => option && option.trim().length > 0 + ); + + const scrollMainContentToBottom = useCallback(() => { + const container = scrollContainerRef.current; + if (!container) return; + container.scrollTop = container.scrollHeight; + }, []); + + useEffect(() => { + if (!isAiLoading && !isSuggesting && !hasContinuationOptions) return; + + const raf = window.requestAnimationFrame(() => { + scrollMainContentToBottom(); + }); + return () => window.cancelAnimationFrame(raf); + }, [ + chapter.content, + continuations, + isAiLoading, + isSuggesting, + hasContinuationOptions, + scrollMainContentToBottom, + ]); return (
( variant="ghost" className="text-xs h-7" onClick={() => onAiAction('chapter', 'extend')} - disabled={isAiLoading} + disabled={isAiLoading || !isWritingAvailable} icon={} - title="Extend Chapter (WRITING model)" + title={ + !isWritingAvailable + ? writingUnavailableReason + : 'Extend Chapter (WRITING model)' + } > Extend @@ -542,9 +575,13 @@ export const Editor = React.forwardRef( variant="ghost" className="text-xs h-7" onClick={() => onAiAction('chapter', 'rewrite')} - disabled={isAiLoading} + disabled={isAiLoading || !isWritingAvailable} icon={} - title="Rewrite Chapter (WRITING model)" + title={ + !isWritingAvailable + ? writingUnavailableReason + : 'Rewrite Chapter (WRITING model)' + } > Rewrite @@ -555,6 +592,7 @@ export const Editor = React.forwardRef( {/* Main Scrollable Content Area */}
(
- {continuations.length > 0 ? ( + {hasContinuationOptions ? (
@@ -666,40 +704,49 @@ export const Editor = React.forwardRef(
- {continuations.map((option, idx) => ( -
onAcceptContinuation(option)} - className={`group relative p-5 rounded-lg border cursor-pointer transition-all hover:shadow-lg hover:-translate-y-0.5 ${ - settings.theme === 'light' - ? 'bg-brand-gray-50 border-brand-gray-200 hover:bg-brand-gray-50 hover:border-brand-300' - : 'bg-brand-gray-800 border-brand-gray-700 hover:bg-brand-gray-750 hover:border-brand-500/50' - }`} - > + {continuations.map((option, idx) => { + if (!option || option.trim().length === 0) { + return null; + } + return (
onAcceptContinuation(option)} + className={`group relative p-5 rounded-lg border cursor-pointer transition-all hover:shadow-lg hover:-translate-y-0.5 ${ settings.theme === 'light' - ? 'text-brand-gray-800' - : 'text-brand-gray-300 group-hover:text-brand-gray-200' + ? 'bg-brand-gray-50 border-brand-gray-200 hover:bg-brand-gray-50 hover:border-brand-300' + : 'bg-brand-gray-800 border-brand-gray-700 hover:bg-brand-gray-750 hover:border-brand-500/50' }`} > - {option} +
+ {option} +
-
- ))} + ); + })}
) : (
@@ -476,9 +482,13 @@ export const HeaderCenterControls: React.FC = ({ variant="ghost" className="text-xs h-6" onClick={() => handleAiAction('chapter', 'rewrite')} - disabled={isAiActionLoading} + disabled={isAiActionLoading || !isWritingAvailable} icon={} - title="Rewrite Chapter (WRITING model)" + title={ + !isWritingAvailable + ? writingUnavailableReason + : 'Rewrite Chapter (WRITING model)' + } > Rewrite diff --git a/src/frontend/features/layout/layoutControlTypes.ts b/src/frontend/features/layout/layoutControlTypes.ts index ec2f5bf..0b633ee 100644 --- a/src/frontend/features/layout/layoutControlTypes.ts +++ b/src/frontend/features/layout/layoutControlTypes.ts @@ -66,6 +66,7 @@ export type HeaderAiControls = { action: 'update' | 'rewrite' | 'extend' ) => Promise; isAiActionLoading: boolean; + isWritingAvailable: boolean; }; export type HeaderModelControls = { @@ -116,11 +117,12 @@ export type MainSidebarControls = { handleReorderChapters: (chapterIds: number[], bookId?: string) => Promise; handleReorderBooks: (bookIds: string[]) => Promise; handleSidebarAiAction: ( - type: 'chapter' | 'book', + type: 'chapter' | 'book' | 'story', id: string, action: 'write' | 'update' | 'rewrite', onProgress?: (text: string) => void ) => Promise; + isEditingAvailable: boolean; handleOpenImages: () => void; updateStoryMetadata: ( updates: Partial<{ @@ -156,6 +158,7 @@ export type MainEditorAiControls = { action: 'update' | 'rewrite' | 'extend' ) => Promise; isAiActionLoading: boolean; + isWritingAvailable: boolean; }; export type MainEditorControls = { @@ -175,6 +178,7 @@ export type MainChatControls = { isChatOpen: boolean; chatMessages: ChatMessage[]; isChatLoading: boolean; + isChatAvailable: boolean; systemPrompt: string; handleSendMessage: (text: string) => Promise; handleStopChat: () => void; diff --git a/src/frontend/features/projects/ProjectImages.tsx b/src/frontend/features/projects/ProjectImages.tsx index ed0da9d..e370999 100644 --- a/src/frontend/features/projects/ProjectImages.tsx +++ b/src/frontend/features/projects/ProjectImages.tsx @@ -46,6 +46,7 @@ interface ProjectImagesProps { onClose: () => void; theme: AppTheme; settings: AppSettings; + imageActionsAvailable?: boolean; prompts?: { system_messages: Record; user_prompts: Record; @@ -61,6 +62,7 @@ export const ProjectImages: React.FC = ({ onClose, theme = 'mixed', settings, + imageActionsAvailable = true, prompts, imageStyle = '', imageAdditionalInfo = '', @@ -162,6 +164,7 @@ export const ProjectImages: React.FC = ({ }; const handleGenerateDescription = async (img: ImageEntry) => { + if (!imageActionsAvailable) return; if (generating) return; setGenerating(img.filename); setError(null); @@ -199,6 +202,7 @@ export const ProjectImages: React.FC = ({ }; const handleCreatePrompt = async (img: ImageEntry) => { + if (!imageActionsAvailable) return; if (!img.description) return; setPromptPopup({ isOpen: true, content: '', loading: true }); @@ -227,6 +231,7 @@ export const ProjectImages: React.FC = ({ }; const handleGenerateAllPrompts = async () => { + if (!imageActionsAvailable) return; const placeholders = images.filter((i) => i.is_placeholder); if (placeholders.length === 0) return; @@ -470,6 +475,7 @@ export const ProjectImages: React.FC = ({
+
+ +
+ +
+ {getPresetById(activeProvider.presetId)?.description && ( +

+ {getPresetById(activeProvider.presetId)?.description} +

+ )} +
)}
+ {suggestedPresetByProvider[activeProvider.id] && ( +
+ + Suggested preset:{' '} + { + getPresetById(suggestedPresetByProvider[activeProvider.id]) + ?.name + } + + +
+ )} {/* Model availability indicator */}
= ({
{renderSlider('Temperature', 'temperature', 0, 2, 0.1)} {renderSlider('Top P', 'topP', 0, 1, 0.05)} + {renderNumberInput('Max Tokens', 'maxTokens')} + {renderNumberInput('Seed', 'seed')} + {renderNumberInput('Presence Penalty', 'presencePenalty')} + {renderNumberInput('Frequency Penalty', 'frequencyPenalty')} + {renderNumberInput('Top K', 'topK')} + {renderNumberInput('Min P', 'minP')} +
+
+
+ +