Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions src/client/app/content/config/tabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from streamlit import session_state as state

from client.app.core import helpers
from client.app.core.api import api_delete, api_get, api_post, api_put
from client.app.core.api import api_delete, api_get, api_post, api_post_stream, api_put

LOGGER = logging.getLogger("client.content.config.tabs.models")

Expand Down Expand Up @@ -324,20 +324,52 @@ def edit_model(
st.rerun()


#####################################################
# Pull Dialog
#####################################################
@st.dialog("Pull Ollama Model")
def pull_model_dialog(provider: str, model_id: str) -> None:
"""Stream Ollama model pull progress."""
st.write(f"Pulling **{provider}/{model_id}** from Ollama registry...")
quoted_id = urllib.parse.quote(model_id, safe="")
status_text = st.empty()
progress_bar = st.empty()

try:
for event in api_post_stream(f"models/pull/{provider}/{quoted_id}"):
if "error" in event:
st.error(f"Pull failed: {event['error']}")
return
status = event.get("status", "")
completed = event.get("completed", 0)
total = event.get("total", 0)
if total > 0:
progress_bar.progress(completed / total, text=status)
else:
status_text.text(status)
except httpx.HTTPStatusError as exc:
st.error(f"Pull failed: {helpers.extract_error_detail(exc)}")
return

helpers.refresh_settings()
st.success(f"Model **{model_id}** pulled successfully. You can now enable it.")


#####################################################
# Table Display
#####################################################
def render_model_rows(model_type: str) -> None:
"""Render rows of the models."""
models = [m for m in state["settings"]["model_configs"] if m.get("type") == model_type]
data_col_widths = [0.06, 0.44, 0.38, 0.12]
data_col_widths = [0.06, 0.40, 0.34, 0.10, 0.10]

table_col_format = st.columns(data_col_widths, vertical_alignment="center")
col1, col2, col3, col4 = table_col_format
col1, col2, col3, col4, col5 = table_col_format
col1.markdown("​", unsafe_allow_html=True, width="content")
col2.markdown("**<u>Model</u>**", unsafe_allow_html=True)
col3.markdown("**<u>Provider URL</u>**", unsafe_allow_html=True)
col4.markdown("&#x200B;")
col5.markdown("&#x200B;")

for model in models:
model_id = model["id"]
Expand Down Expand Up @@ -375,6 +407,13 @@ def render_model_rows(model_type: str) -> None:
"model_provider": model_provider,
},
)
if model_provider == "ollama" and not model.get("usable", False):
col5.button(
"Pull",
on_click=pull_model_dialog,
key=f"runtime_{model_type}_{model_provider}_{model_id}_pull",
kwargs={"provider": model_provider, "model_id": model_id},
)

if st.button(label="Add", type="primary", key=f"add_{model_type}_model"):
edit_model(model_type=model_type, action="add")
Expand Down
25 changes: 25 additions & 0 deletions src/client/app/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
# spell-checker:ignore apiserver pypath

import atexit
import json
import logging
import os
import secrets
import subprocess
import sys
import time
from collections.abc import Generator
from io import TextIOWrapper
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -206,6 +208,29 @@ def api_post(
return resp.json()


def api_post_stream(
    path: str,
    json_body: dict | None = None,
    timeout: int = 600,
    api_prefix: str = "/v1",
) -> Generator[dict, None, None]:
    """Streaming POST request to the API server. Yields parsed NDJSON dicts.

    Blank lines and lines that fail to parse as JSON are skipped silently.
    Raises ``httpx.HTTPStatusError`` for a non-2xx response status.
    """
    endpoint = f"{_base_url(api_prefix)}/{path.lstrip('/')}"
    with httpx.Client(headers=_headers(), timeout=timeout) as client:
        with client.stream("POST", endpoint, json=json_body) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                payload = line.strip()
                if not payload:
                    continue
                try:
                    parsed = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                yield parsed


def api_put(
path: str,
json: dict | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/client/tests/content/config/tabs/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,7 +998,7 @@ def test_renders_rows_only_for_matching_type(self, make_model_state, mock_st):
]
state = make_model_state(model_configs=configs)

cols = [MagicMock(), MagicMock(), MagicMock(), MagicMock()]
cols = [MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()]
mock_st.columns.side_effect = lambda widths, **kw: cols
mock_st.button.return_value = False

Expand All @@ -1020,7 +1020,7 @@ def test_edit_button_kwargs(self, make_model_state, mock_st):
]
state = make_model_state(model_configs=configs)

cols = [MagicMock(), MagicMock(), MagicMock(), MagicMock()]
cols = [MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()]
mock_st.columns.side_effect = lambda widths, **kw: cols
mock_st.button.return_value = False

Expand Down
35 changes: 35 additions & 0 deletions src/server/app/api/v1/endpoints/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
"""
# spell-checker:ignore litellm ollama rerank

import json
from typing import Optional

import litellm
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import StreamingResponse

from server.app.core.settings import _settings_lock, settings
from server.app.database.settings import persist_settings
from server.app.models.connectivity import check_single_model
from server.app.models.litellm_utils import find_model
from server.app.models.ollama import pull_ollama_model
from server.app.models.schemas import ModelConfig, ModelSensitive, ModelUpdate, SupportedProviderIds

litellm.suppress_debug_info = True
Expand Down Expand Up @@ -100,6 +103,38 @@ async def models_supported(
return _get_supported(model_provider=model_provider, model_type=model_type)


# --- Pull endpoint (must be before /{provider}/{model_id:path}) ---


@auth.post("/pull/{provider}/{model_id:path}")
async def pull_model(provider: str, model_id: str):
"""Pull an Ollama model and stream progress as NDJSON."""
if provider.casefold() != "ollama":
raise HTTPException(status_code=400, detail="Pull is only supported for Ollama models")

cfg = _find_model(provider, model_id)
if cfg is None:
raise HTTPException(status_code=404, detail=f"Model config not found: {provider}/{model_id}")
if not cfg.api_base:
raise HTTPException(status_code=400, detail=f"Model {provider}/{model_id} has no API base URL configured")
api_base: str = cfg.api_base

async def _stream():
error_occurred = False
async for event in pull_ollama_model(api_base, model_id):
if "error" in event:
error_occurred = True
yield json.dumps(event) + "\n"
if not error_occurred:
await check_single_model(cfg)
if not await persist_settings():
yield json.dumps({"error": _PERSIST_FAIL}) + "\n"
else:
yield json.dumps({"status": "success"}) + "\n"

return StreamingResponse(_stream(), media_type="application/x-ndjson")


@auth.get("/{provider}/{model_id:path}", response_model=ModelConfig, response_model_exclude_unset=True)
async def get_model(provider: str, model_id: str, include_sensitive: bool = Query(default=False)):
"""Return a single model configuration by provider and id (case-insensitive)."""
Expand Down
28 changes: 28 additions & 0 deletions src/server/app/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
"""
# spell-checker: ignore ollama nomic

import json
import logging
import os
from collections.abc import AsyncGenerator

import httpx

Expand Down Expand Up @@ -129,3 +131,29 @@ async def load_ollama_models() -> None:
LOGGER.info("Removed %d Ollama model(s) no longer pulled", len(removed))

LOGGER.info("Discovered %d Ollama model(s) at %s", len(models), api_base)


async def pull_ollama_model(api_base: str, model_name: str) -> AsyncGenerator[dict, None]:
    """Stream pull progress from an Ollama server as dicts.

    Each yielded dict mirrors the NDJSON lines from Ollama's ``/api/pull``
    endpoint (keys like ``status``, ``completed``, ``total``, ``digest``).
    On error an ``{"error": "..."}`` dict is yielded.
    """
    endpoint = f"{api_base.rstrip('/')}/api/pull"
    # Layer downloads can leave long gaps between progress lines, so the
    # read timeout is much longer than the connect timeout.
    timeouts = httpx.Timeout(120.0, connect=CONNECT_TIMEOUT)
    try:
        async with (
            httpx.AsyncClient(timeout=timeouts) as client,
            client.stream("POST", endpoint, json={"name": model_name}) as response,
        ):
            response.raise_for_status()
            async for line in response.aiter_lines():
                text = line.strip()
                if not text:
                    continue
                try:
                    parsed = json.loads(text)
                except json.JSONDecodeError:
                    continue
                yield parsed
    except httpx.HTTPError as exc:
        LOGGER.warning("Ollama pull failed for '%s' at %s: %s", model_name, api_base, exc)
        yield {"error": str(exc)}
Loading