diff --git a/src/ucode/databricks.py b/src/ucode/databricks.py index 574d906..9e856fd 100644 --- a/src/ucode/databricks.py +++ b/src/ucode/databricks.py @@ -49,6 +49,54 @@ def _debug_enabled() -> bool: return os.environ.get("UCODE_DEBUG") == "1" +def _preferred_model_prefix() -> str | None: + """Return the user-configured model prefix preference, or None. + + Set ``UCODE_MODEL_PREFERRED_PREFIX`` to a prefix string (e.g. ``companyname-``) to + promote any discovered model whose name starts with that prefix to the top + of each family's candidate list. The normal newest-first ordering still + applies within the preferred and non-preferred groups, so the best version + of the preferred prefix always wins over an older version of the same prefix. + """ + return os.environ.get("UCODE_MODEL_PREFERRED_PREFIX") or None + + +def _model_bare_name(model_id: str) -> str: + """Return the bare model name, stripping any catalog prefix (e.g. ``system.ai.``).""" + return model_id.rsplit(".", 1)[-1] + + +def _sort_candidates_by_prefix( + candidates: list[str], + preferred_prefix: str | None, + *, + sort_key=None, +) -> list[str]: + """Sort model id candidates newest-first, with preferred_prefix promoted. + + ``sort_key`` overrides the default reverse-alpha ordering within each group + (e.g. pass ``model_version_sort_key`` for Gemini). Without a preferred + prefix this is equivalent to ``sorted(candidates, key=sort_key)`` (or + ``sorted(..., reverse=True)`` for the default key). With one, preferred + items sort before others; both groups use the same criterion. + + Prefix matching is anchored to the start of the bare model name so that + ``system.ai.companyname-claude-haiku-4-5`` matches prefix ``companyname-`` + without false-positives from mid-string occurrences. + """ + if not preferred_prefix: + return sorted(candidates, key=sort_key) if sort_key else sorted(candidates, reverse=True) + preferred = [m for m in candidates if _model_bare_name(m).startswith(preferred_prefix)] + others = [m for m in candidates if not _model_bare_name(m).startswith(preferred_prefix)] + if sort_key: + return sorted(preferred, key=sort_key) + sorted(others, key=sort_key) + return sorted( + candidates, + key=lambda m: (_model_bare_name(m).startswith(preferred_prefix), m), + reverse=True, + ) + + _DEBUG_LOGGER: logging.Logger | None = None @@ -1174,17 +1222,20 @@ def discover_model_services( if not ids: return {}, [], [], reason + preferred_prefix = _preferred_model_prefix() claude_models: dict[str, str] = {} for family in ("opus", "sonnet", "haiku"): - candidates = sorted( + candidates = _sort_candidates_by_prefix( [m for m in ids if f"claude-{family}-" in m], - reverse=True, + preferred_prefix, ) if candidates: claude_models[family] = candidates[0] - codex_models = [m for m in ids if "gpt-" in m] - gemini_models = sorted([m for m in ids if "gemini-" in m], key=model_version_sort_key) + codex_models = _sort_candidates_by_prefix([m for m in ids if "gpt-" in m], preferred_prefix) + gemini_models = _sort_candidates_by_prefix( + [m for m in ids if "gemini-" in m], preferred_prefix, sort_key=model_version_sort_key + ) if not (claude_models or codex_models or gemini_models): sample = ", ".join(ids[:5]) @@ -1270,12 +1321,21 @@ def discover_claude_models(workspace: str, token: str) -> tuple[dict[str, str], if isinstance(m.get("id"), str) and not m["id"].endswith("-anthropic") ] + preferred_prefix = _preferred_model_prefix() result: dict[str, str] = {} for family, key in [("opus", "opus"), ("sonnet", "sonnet"), ("haiku", "haiku")]: - candidates = sorted( - [m for m in raw_ids if f"databricks-claude-{family}-" in m], - reverse=True, - ) + # Accept both the standard databricks-claude-* prefix and any custom + # preferred prefix so that e.g. companyname-claude-haiku-4-5 is discoverable. + family_ids = [ + m + for m in raw_ids + if f"claude-{family}-" in m + and ( + f"databricks-claude-{family}-" in m + or (preferred_prefix and _model_bare_name(m).startswith(preferred_prefix)) + ) + ] + candidates = _sort_candidates_by_prefix(family_ids, preferred_prefix) if candidates: result[key] = candidates[0] if result: @@ -1328,12 +1388,14 @@ def discover_endpoints_with_api_type( api_type: str, *, sort_key=None, + preferred_prefix: str | None = None, ) -> tuple[list[str], str | None]: """List endpoint names whose served_entities expose api_type with v2 support. Returns (endpoints, reason). reason is None on success; otherwise it describes why the list is empty. `sort_key` overrides the default - alphabetical ordering of the returned names. + newest-first (reverse-alphabetical) ordering of the returned names. `preferred_prefix` promotes + matching endpoints to the front of the list (see ``_preferred_model_prefix``). """ hostname = workspace_hostname(workspace) payload, reason = _http_get_json( @@ -1361,7 +1423,7 @@ def discover_endpoints_with_api_type( if api_type in api_types: out.append(name) if out: - return sorted(out, key=sort_key), None + return _sort_candidates_by_prefix(out, preferred_prefix, sort_key=sort_key), None if not endpoints: return [], "foundation-models listing returned no endpoints" if saw_endpoint_without_v2: @@ -1382,12 +1444,18 @@ def discover_gemini_models(workspace: str, token: str) -> tuple[list[str], str | # Order newest model version first so `default_model()` (which picks the # first entry) launches e.g. gemini-3.5-flash rather than gemini-2.5-flash. return discover_endpoints_with_api_type( - workspace, token, "gemini/v1/generateContent", sort_key=model_version_sort_key + workspace, + token, + "gemini/v1/generateContent", + sort_key=model_version_sort_key, + preferred_prefix=_preferred_model_prefix(), ) def discover_codex_models(workspace: str, token: str) -> tuple[list[str], str | None]: - return discover_endpoints_with_api_type(workspace, token, "openai/v1/responses") + return discover_endpoints_with_api_type( + workspace, token, "openai/v1/responses", preferred_prefix=_preferred_model_prefix() + ) def fetch_gemini_models(workspace: str, token: str) -> list[str]: diff --git a/tests/test_databricks.py b/tests/test_databricks.py index 65f05d8..0119176 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -131,6 +131,21 @@ def test_selects_opus_4_8_when_advertised(self, monkeypatch): assert reason is None assert models["opus"] == "databricks-claude-opus-4-8" + def test_preferred_prefix_promotes_matching_model(self, monkeypatch): + payload = { + "data": [ + {"id": "databricks-claude-haiku-4-5"}, + {"id": "companyname-claude-haiku-4-5"}, + ] + } + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-") + + models, reason = db_mod.discover_claude_models(WS, "token") + + assert reason is None + assert models["haiku"] == "companyname-claude-haiku-4-5" + def _model_service(model_id: str) -> dict: """A model-services entry whose `name` strips to `model_id`.""" @@ -236,6 +251,23 @@ def test_ignores_non_system_ai_schemas(self, monkeypatch): assert claude == {} # temp.erni.claude-* must not be bucketed assert gemini == [] + def test_preferred_prefix_promotes_matching_claude_model(self, monkeypatch): + payload = { + "model_services": [ + _model_service("system.ai.databricks-claude-haiku-4-5"), + _model_service("system.ai.companyname-claude-haiku-4-5"), + ] + } + monkeypatch.setattr( + db_mod, "_http_get_json", lambda url, token, timeout=10: (payload, None) + ) + monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-") + + claude, _, _, reason = db_mod.discover_model_services(WS, "token") + + assert reason is None + assert claude["haiku"] == "system.ai.companyname-claude-haiku-4-5" + def test_requests_bounded_page_size(self, monkeypatch): # The endpoint 499s without a bounded page_size, so every request must # carry one. @@ -445,9 +477,7 @@ def test_returns_newest_flash_first(self, monkeypatch): assert reason is None assert models[0] == "databricks-gemini-3-5-flash" - def test_codex_discovery_keeps_alphabetical_order(self, monkeypatch): - # Codex passes no sort_key, so ordering must stay the plain alphabetical - # default — guarding against the gemini change leaking across tools. + def test_codex_discovery_orders_newest_first(self, monkeypatch): payload = { "endpoints": [ { @@ -471,7 +501,46 @@ def test_codex_discovery_keeps_alphabetical_order(self, monkeypatch): models, reason = db_mod.discover_codex_models(WS, "token") assert reason is None - assert models == ["databricks-gpt-4-1", "databricks-gpt-5-2-codex"] + assert models == ["databricks-gpt-5-2-codex", "databricks-gpt-4-1"] + + def test_preferred_prefix_promotes_matching_gemini_model(self, monkeypatch): + payload = _foundation_models_payload( + ["databricks-gemini-3-5-flash", "companyname-gemini-3-5-flash"] + ) + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-") + + models, reason = db_mod.discover_gemini_models(WS, "token") + + assert reason is None + assert models[0] == "companyname-gemini-3-5-flash" + + def test_preferred_prefix_promotes_matching_codex_model(self, monkeypatch): + payload = { + "endpoints": [ + { + "name": name, + "config": { + "served_entities": [ + { + "foundation_model": { + "ai_gateway_v2_supported": True, + "api_types": ["openai/v1/responses"], + } + } + ] + }, + } + for name in ["databricks-gpt-5", "companyname-gpt-5"] + ] + } + monkeypatch.setattr(db_mod, "_http_get_json", lambda url, token: (payload, None)) + monkeypatch.setenv("UCODE_MODEL_PREFERRED_PREFIX", "companyname-") + + models, reason = db_mod.discover_codex_models(WS, "token") + + assert reason is None + assert models[0] == "companyname-gpt-5" class TestResolvePatToken: