Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 80 additions & 12 deletions src/ucode/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,54 @@ def _debug_enabled() -> bool:
return os.environ.get("UCODE_DEBUG") == "1"


def _preferred_model_prefix() -> str | None:
"""Return the user-configured model prefix preference, or None.

Set ``UCODE_MODEL_PREFERRED_PREFIX`` to a prefix string (e.g. ``companyname-``) to
promote any discovered model whose name starts with that prefix to the top
of each family's candidate list. The normal newest-first ordering still
applies within the preferred and non-preferred groups, so the best version
of the preferred prefix always wins over an older version of the same prefix.
"""
return os.environ.get("UCODE_MODEL_PREFERRED_PREFIX") or None
Comment thread
tplass-ias marked this conversation as resolved.


def _model_bare_name(model_id: str) -> str:
"""Return the bare model name, stripping any catalog prefix (e.g. ``system.ai.``)."""
return model_id.rsplit(".", 1)[-1]


def _sort_candidates_by_prefix(
candidates: list[str],
preferred_prefix: str | None,
*,
sort_key=None,
) -> list[str]:
"""Sort model id candidates newest-first, with preferred_prefix promoted.

``sort_key`` overrides the default reverse-alpha ordering within each group
(e.g. pass ``model_version_sort_key`` for Gemini). Without a preferred
prefix this is equivalent to ``sorted(candidates, key=sort_key)`` (or
``sorted(..., reverse=True)`` for the default key). With one, preferred
items sort before others; both groups use the same criterion.

Prefix matching is anchored to the start of the bare model name so that
``system.ai.companyname-claude-haiku-4-5`` matches prefix ``companyname-``
without false-positives from mid-string occurrences.
"""
if not preferred_prefix:
return sorted(candidates, key=sort_key) if sort_key else sorted(candidates, reverse=True)
preferred = [m for m in candidates if _model_bare_name(m).startswith(preferred_prefix)]
others = [m for m in candidates if not _model_bare_name(m).startswith(preferred_prefix)]
if sort_key:
return sorted(preferred, key=sort_key) + sorted(others, key=sort_key)
return sorted(
candidates,
key=lambda m: (_model_bare_name(m).startswith(preferred_prefix), m),
reverse=True,
)
Comment on lines +89 to +97


_DEBUG_LOGGER: logging.Logger | None = None


Expand Down Expand Up @@ -1174,17 +1222,20 @@ def discover_model_services(
if not ids:
return {}, [], [], reason

preferred_prefix = _preferred_model_prefix()
claude_models: dict[str, str] = {}
for family in ("opus", "sonnet", "haiku"):
candidates = sorted(
candidates = _sort_candidates_by_prefix(
[m for m in ids if f"claude-{family}-" in m],
reverse=True,
preferred_prefix,
)
if candidates:
claude_models[family] = candidates[0]

codex_models = [m for m in ids if "gpt-" in m]
gemini_models = sorted([m for m in ids if "gemini-" in m], key=model_version_sort_key)
codex_models = _sort_candidates_by_prefix([m for m in ids if "gpt-" in m], preferred_prefix)
gemini_models = _sort_candidates_by_prefix(
[m for m in ids if "gemini-" in m], preferred_prefix, sort_key=model_version_sort_key
)

if not (claude_models or codex_models or gemini_models):
sample = ", ".join(ids[:5])
Expand Down Expand Up @@ -1270,12 +1321,21 @@ def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str],
if isinstance(m.get("id"), str) and not m["id"].endswith("-anthropic")
]

preferred_prefix = _preferred_model_prefix()
result: dict[str, str] = {}
for family, key in [("opus", "opus"), ("sonnet", "sonnet"), ("haiku", "haiku")]:
candidates = sorted(
[m for m in raw_ids if f"databricks-claude-{family}-" in m],
reverse=True,
)
# Accept both the standard databricks-claude-* prefix and any custom
# preferred prefix so that e.g. companyname-claude-haiku-4-5 is discoverable.
family_ids = [
m
for m in raw_ids
if f"claude-{family}-" in m
and (
f"databricks-claude-{family}-" in m
or (preferred_prefix and _model_bare_name(m).startswith(preferred_prefix))
)
]
Comment thread
tplass-ias marked this conversation as resolved.
Comment on lines +1327 to +1337
candidates = _sort_candidates_by_prefix(family_ids, preferred_prefix)
if candidates:
result[key] = candidates[0]
if result:
Expand Down Expand Up @@ -1328,12 +1388,14 @@ def discover_endpoints_with_api_type(
api_type: str,
*,
sort_key=None,
preferred_prefix: str | None = None,
) -> tuple[list[str], str | None]:
"""List endpoint names whose served_entities expose api_type with v2 support.

Returns (endpoints, reason). reason is None on success; otherwise it
describes why the list is empty. `sort_key` overrides the default
alphabetical ordering of the returned names.
newest-first (reverse-alphabetical) ordering of the returned names. `preferred_prefix` promotes
matching endpoints to the front of the list (see ``_preferred_model_prefix``).
Comment thread
tplass-ias marked this conversation as resolved.
Comment on lines 1396 to +1398
Comment on lines 1396 to +1398
"""
hostname = workspace_hostname(workspace)
payload, reason = _http_get_json(
Expand Down Expand Up @@ -1361,7 +1423,7 @@ def discover_endpoints_with_api_type(
if api_type in api_types:
out.append(name)
if out:
return sorted(out, key=sort_key), None
return _sort_candidates_by_prefix(out, preferred_prefix, sort_key=sort_key), None
if not endpoints:
return [], "foundation-models listing returned no endpoints"
if saw_endpoint_without_v2:
Expand All @@ -1382,12 +1444,18 @@ def discover_gemini_models(workspace: str, token: str) -> tuple[list[str], str |
# Order newest model version first so `default_model()` (which picks the
# first entry) launches e.g. gemini-3.5-flash rather than gemini-2.5-flash.
return discover_endpoints_with_api_type(
workspace, token, "gemini/v1/generateContent", sort_key=model_version_sort_key
workspace,
token,
"gemini/v1/generateContent",
sort_key=model_version_sort_key,
preferred_prefix=_preferred_model_prefix(),
)


def discover_codex_models(workspace: str, token: str) -> tuple[list[str], str | None]:
return discover_endpoints_with_api_type(workspace, token, "openai/v1/responses")
return discover_endpoints_with_api_type(
workspace, token, "openai/v1/responses", preferred_prefix=_preferred_model_prefix()
)


def fetch_gemini_models(workspace: str, token: str) -> list[str]:
Expand Down
77 changes: 73 additions & 4 deletions tests/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ def test_selects_opus_4_8_when_advertised(self, monkeypatch):
assert reason is None
assert models["opus"] == "databricks-claude-opus-4-8"

def test_preferred_prefix_promotes_matching_model(self, monkeypatch):
payload = {
"data": [
{"id": "databricks-claude-haiku-4-5"},
{"id": "companyname-claude-haiku-4-5"},
]
}
monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None))
monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-")

models, reason = db_mod.discover_claude_models(WS, "token")

assert reason is None
assert models["haiku"] == "companyname-claude-haiku-4-5"


def _model_service(model_id: str) -> dict:
"""A model-services entry whose `name` strips to `model_id`."""
Expand Down Expand Up @@ -236,6 +251,23 @@ def test_ignores_non_system_ai_schemas(self, monkeypatch):
assert claude == {} # temp.erni.claude-* must not be bucketed
assert gemini == []

def test_preferred_prefix_promotes_matching_claude_model(self, monkeypatch):
payload = {
"model_services": [
_model_service("system.ai.databricks-claude-haiku-4-5"),
_model_service("system.ai.companyname-claude-haiku-4-5"),
]
}
monkeypatch.setattr(
db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None)
)
monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-")

claude, _, _, reason = db_mod.discover_model_services(WS, "token")

assert reason is None
assert claude["haiku"] == "system.ai.companyname-claude-haiku-4-5"

def test_requests_bounded_page_size(self, monkeypatch):
# The endpoint 499s without a bounded page_size, so every request must
# carry one.
Expand Down Expand Up @@ -445,9 +477,7 @@ def test_returns_newest_flash_first(self, monkeypatch):
assert reason is None
assert models[0] == "databricks-gemini-3-5-flash"

def test_codex_discovery_keeps_alphabetical_order(self, monkeypatch):
# Codex passes no sort_key, so ordering must stay the plain alphabetical
# default — guarding against the gemini change leaking across tools.
def test_codex_discovery_orders_newest_first(self, monkeypatch):
payload = {
"endpoints": [
{
Expand All @@ -471,7 +501,46 @@ def test_codex_discovery_keeps_alphabetical_order(self, monkeypatch):
models, reason = db_mod.discover_codex_models(WS, "token")

assert reason is None
assert models == ["databricks-gpt-4-1", "databricks-gpt-5-2-codex"]
assert models == ["databricks-gpt-5-2-codex", "databricks-gpt-4-1"]

def test_preferred_prefix_promotes_matching_gemini_model(self, monkeypatch):
payload = _foundation_models_payload(
["databricks-gemini-3-5-flash", "companyname-gemini-3-5-flash"]
)
monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None))
monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-")

models, reason = db_mod.discover_gemini_models(WS, "token")

assert reason is None
assert models[0] == "companyname-gemini-3-5-flash"

def test_preferred_prefix_promotes_matching_codex_model(self, monkeypatch):
payload = {
"endpoints": [
{
"name": name,
"config": {
"served_entities": [
{
"foundation_model": {
"ai_gateway_v2_supported": True,
"api_types": ["openai/v1/responses"],
}
}
]
},
}
for name in ["databricks-gpt-5", "companyname-gpt-5"]
]
}
monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None))
monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-")

models, reason = db_mod.discover_codex_models(WS, "token")

assert reason is None
assert models[0] == "companyname-gpt-5"


class TestResolvePatToken:
Expand Down