From 18ee8a82ad1e51a369a7d9cafbde4023840f789d Mon Sep 17 00:00:00 2001 From: w-angxu Date: Wed, 10 Jun 2026 15:02:17 +0800 Subject: [PATCH] Fix RestClient checkpoint listing and publish/unpublish against Logits backend The Logits backend diverged from the upstream Tinker schema/routes: - Checkpoint objects are returned with `logits_path` / `size` / `visibility` / `expired_at`, but the upstream `Checkpoint` model requires `tinker_path` and uses `size_bytes` / `public` / `expires_at`. `list_checkpoints` / `list_user_checkpoints` therefore raised APIResponseValidationError. Parse the backend's real field names via `_BackendCheckpoint` and translate back to the upstream `Checkpoint` type so callers keep the documented shape. - The publish/unpublish routes accept only the opaque `chk_...` checkpoint id, but the `logits_path` the backend advertises (and that `save_state` returns) embeds the human-readable checkpoint name. Resolve name -> checkpoint_id via the run's checkpoint list before calling publish (POST) / unpublish (DELETE). Verified end-to-end against api.logits.dev: list_checkpoints, publish and unpublish now work with the documented `logits://<training_run_id>/weights/<name>` paths. Adds unit tests for the field translation. 
Co-Authored-By: Claude Opus 4.8 --- src/logits/_rest_client.py | 183 +++++++++++++++++++++++++++++++++---- tests/test_rest_client.py | 65 ++++++++++++- 2 files changed, 231 insertions(+), 17 deletions(-) diff --git a/src/logits/_rest_client.py b/src/logits/_rest_client.py index 9945698..755e5c5 100644 --- a/src/logits/_rest_client.py +++ b/src/logits/_rest_client.py @@ -1,14 +1,76 @@ from __future__ import annotations from concurrent.futures import Future as ConcurrentFuture +from datetime import datetime from typing import Literal from tinker import types +from tinker._models import BaseModel as _SdkBaseModel +from tinker._types import NoneType +from tinker.lib.client_connection_pool_type import ClientConnectionPoolType +from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture from tinker.lib.public_interfaces.rest_client import RestClient as TinkerRestClient from ._path_compat import normalize_tinker_path +class _BackendCheckpoint(_SdkBaseModel): + """Mirror of the checkpoint object the Logits backend actually returns. + + The upstream Tinker ``Checkpoint`` model requires ``tinker_path`` and uses + ``size_bytes`` / ``public`` / ``expires_at``. The Logits backend instead + sends ``logits_path``, ``size``, ``visibility`` and ``expired_at`` (and a + human-readable ``name``). Because the upstream model marks ``tinker_path`` + as required, the raw list call fails validation before we ever see the + data. We parse the backend's real field names here and translate back to the + upstream type so callers keep the documented ``Checkpoint`` shape. 
+ """ + + checkpoint_id: str + checkpoint_type: str | None = None + name: str | None = None + time: datetime | None = None + logits_path: str | None = None + tinker_path: str | None = None + size: int | None = None + size_bytes: int | None = None + visibility: str | None = None + public: bool | None = None + expired_at: datetime | None = None + expires_at: datetime | None = None + + +class _BackendCheckpointsList(_SdkBaseModel): + checkpoints: list[_BackendCheckpoint] = [] + cursor: object | None = None + + +def _to_checkpoint(b: _BackendCheckpoint) -> types.Checkpoint: + public = b.public if b.public is not None else (b.visibility == "public") + return types.Checkpoint( + checkpoint_id=b.checkpoint_id, + checkpoint_type=b.checkpoint_type or "training", # type: ignore[arg-type] + time=b.time, # type: ignore[arg-type] + tinker_path=(b.tinker_path or b.logits_path or ""), + size_bytes=(b.size_bytes if b.size_bytes is not None else b.size), + public=bool(public), + expires_at=(b.expires_at or b.expired_at), + ) + + +def _to_list_response(raw: _BackendCheckpointsList) -> types.CheckpointsListResponse: + cursor = None + if isinstance(raw.cursor, dict): + try: + cursor = types.Cursor(**raw.cursor) + except Exception: + cursor = None + return types.CheckpointsListResponse( + checkpoints=[_to_checkpoint(b) for b in raw.checkpoints], + cursor=cursor, + ) + + class RestClient(TinkerRestClient): """Logits-compatible RestClient. 
@@ -41,6 +103,111 @@ def _parse_checkpoint_path(tinker_path: str) -> tuple[str, str]: checkpoint_id = parsed.checkpoint_id.rsplit("/", 1)[-1] return parsed.training_run_id, checkpoint_id + # ------------------------------------------------------------------ + # Checkpoint listing (backend returns `logits_path`, not `tinker_path`) + # ------------------------------------------------------------------ + async def _fetch_run_checkpoints(self, training_run_id: types.ModelID) -> _BackendCheckpointsList: + async def _send() -> _BackendCheckpointsList: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.get( + f"/api/v1/training_runs/{training_run_id}/checkpoints", + cast_to=_BackendCheckpointsList, + ) + + return await self.holder.execute_with_retries(_send) + + async def _fetch_user_checkpoints(self, limit: int, offset: int) -> _BackendCheckpointsList: + async def _send() -> _BackendCheckpointsList: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.get( + "/api/v1/checkpoints", + options={"params": {"limit": limit, "offset": offset}}, + cast_to=_BackendCheckpointsList, + ) + + return await self.holder.execute_with_retries(_send) + + def _list_checkpoints_submit( + self, training_run_id: types.ModelID + ) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]: + async def _coro() -> types.CheckpointsListResponse: + return _to_list_response(await self._fetch_run_checkpoints(training_run_id)) + + return self.holder.run_coroutine_threadsafe(_coro()) + + def _list_user_checkpoints_submit( + self, limit: int = 100, offset: int = 0 + ) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]: + async def _coro() -> types.CheckpointsListResponse: + return _to_list_response(await self._fetch_user_checkpoints(limit, offset)) + + return self.holder.run_coroutine_threadsafe(_coro()) + + # ------------------------------------------------------------------ + # Publish / unpublish (backend 
resolves these routes by checkpoint_id, + # but the path the backend hands back embeds the human-readable name) + # ------------------------------------------------------------------ + async def _resolve_checkpoint_id(self, training_run_id: types.ModelID, segment: str) -> str: + """Map a path segment (name *or* checkpoint_id) to the real checkpoint_id. + + The publish/unpublish routes only accept the opaque ``chk_...`` id, but + the ``logits_path`` the backend advertises (and that `save_state` + returns) embeds the checkpoint's name. Resolve the name to its id via + the run's checkpoint list; pass through anything that already looks like + an id so we never add a round-trip we don't need. + """ + if segment.startswith("chk_"): + return segment + raw = await self._fetch_run_checkpoints(training_run_id) + for b in raw.checkpoints: + if b.checkpoint_id == segment: + return segment + for b in raw.checkpoints: + if b.name == segment: + return b.checkpoint_id + return segment + + def _resolve_and_set_publish_submit( + self, training_run_id: types.ModelID, segment: str, *, publish: bool + ) -> AwaitableConcurrentFuture[None]: + async def _coro() -> None: + checkpoint_id = await self._resolve_checkpoint_id(training_run_id, segment) + url = f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish" + + async def _send() -> None: + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + if publish: + await client.post(url, cast_to=NoneType) + else: + await client.delete(url, cast_to=NoneType) + + await self.holder.execute_with_retries(_send) + + return self.holder.run_coroutine_threadsafe(_coro()) + + def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: + training_run_id, segment = self._parse_checkpoint_path(tinker_path) + return self._resolve_and_set_publish_submit( + training_run_id, segment, publish=True + ).future() + + async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> 
None: + training_run_id, segment = self._parse_checkpoint_path(tinker_path) + await self._resolve_and_set_publish_submit(training_run_id, segment, publish=True) + + def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: + training_run_id, segment = self._parse_checkpoint_path(tinker_path) + return self._resolve_and_set_publish_submit( + training_run_id, segment, publish=False + ).future() + + async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: + training_run_id, segment = self._parse_checkpoint_path(tinker_path) + await self._resolve_and_set_publish_submit(training_run_id, segment, publish=False) + + # ------------------------------------------------------------------ + # Path helpers unchanged from the original Logits shim + # ------------------------------------------------------------------ def get_training_run_by_tinker_path( self, tinker_path: str, access_scope: Literal["owned", "accessible"] = "owned" ) -> ConcurrentFuture[types.TrainingRun]: @@ -73,22 +240,6 @@ async def get_checkpoint_archive_url_from_tinker_path_async( training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path) return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id) - def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: - training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path) - return self._publish_checkpoint_submit(training_run_id, checkpoint_id).future() - - async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: - training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path) - await self._publish_checkpoint_submit(training_run_id, checkpoint_id) - - def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]: - training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path) - return self._unpublish_checkpoint_submit(training_run_id, 
checkpoint_id).future() - - async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None: - training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path) - await self._unpublish_checkpoint_submit(training_run_id, checkpoint_id) - def set_checkpoint_ttl_from_tinker_path( self, tinker_path: str, ttl_seconds: int | None ) -> ConcurrentFuture[None]: diff --git a/tests/test_rest_client.py b/tests/test_rest_client.py index 49f9dd7..7daee30 100644 --- a/tests/test_rest_client.py +++ b/tests/test_rest_client.py @@ -2,7 +2,13 @@ import pytest -from logits._rest_client import RestClient +from logits._rest_client import ( + RestClient, + _BackendCheckpoint, + _BackendCheckpointsList, + _to_checkpoint, + _to_list_response, +) @pytest.mark.parametrize( @@ -31,3 +37,60 @@ def test_parse_checkpoint_path_strips_kind_prefix( def test_parse_checkpoint_path_rejects_non_weights_uri() -> None: with pytest.raises(ValueError): RestClient._parse_checkpoint_path("https://example.com/model") + + +def test_backend_checkpoint_maps_renamed_fields() -> None: + # The Logits backend sends `logits_path` / `size` / `visibility` / + # `expired_at`; the upstream Checkpoint type requires `tinker_path` and uses + # `size_bytes` / `public` / `expires_at`. The raw list call would fail + # validation without this translation. 
+ raw = _BackendCheckpoint.model_validate( + { + "checkpoint_id": "chk_abc", + "checkpoint_type": "training", + "name": "000010", + "time": "2026-06-09T23:02:26.201258-07:00", + "size": 810405661, + "visibility": "public", + "expired_at": None, + "logits_path": "logits://model_6320e11a/weights/000010", + } + ) + ck = _to_checkpoint(raw) + assert ck.tinker_path == "logits://model_6320e11a/weights/000010" + assert ck.size_bytes == 810405661 + assert ck.public is True + assert ck.expires_at is None + assert ck.checkpoint_type == "training" + + +def test_backend_checkpoint_private_visibility() -> None: + raw = _BackendCheckpoint.model_validate( + { + "checkpoint_id": "chk_def", + "checkpoint_type": "sampler", + "time": "2026-06-09T23:02:26.201258-07:00", + "visibility": "private", + "logits_path": "logits://model_x/sampler_weights/final", + } + ) + assert _to_checkpoint(raw).public is False + + +def test_to_list_response_builds_upstream_type() -> None: + raw = _BackendCheckpointsList.model_validate( + { + "checkpoints": [ + { + "checkpoint_id": "chk_1", + "checkpoint_type": "training", + "time": "2026-06-09T23:02:26.201258-07:00", + "logits_path": "logits://run/weights/a", + } + ] + } + ) + resp = _to_list_response(raw) + assert len(resp.checkpoints) == 1 + assert resp.checkpoints[0].checkpoint_id == "chk_1" + assert resp.checkpoints[0].tinker_path == "logits://run/weights/a"