Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 167 additions & 16 deletions src/logits/_rest_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,76 @@
from __future__ import annotations

from concurrent.futures import Future as ConcurrentFuture
from datetime import datetime
from typing import Literal

from tinker import types
from tinker._models import BaseModel as _SdkBaseModel
from tinker._types import NoneType
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import AwaitableConcurrentFuture
from tinker.lib.public_interfaces.rest_client import RestClient as TinkerRestClient

from ._path_compat import normalize_tinker_path


class _BackendCheckpoint(_SdkBaseModel):
"""Mirror of the checkpoint object the Logits backend actually returns.

The upstream Tinker ``Checkpoint`` model requires ``tinker_path`` and uses
``size_bytes`` / ``public`` / ``expires_at``. The Logits backend instead
sends ``logits_path``, ``size``, ``visibility`` and ``expired_at`` (and a
human-readable ``name``). Because the upstream model marks ``tinker_path``
as required, the raw list call fails validation before we ever see the
data. We parse the backend's real field names here and translate back to the
upstream type so callers keep the documented ``Checkpoint`` shape.
"""

checkpoint_id: str
checkpoint_type: str | None = None
name: str | None = None
time: datetime | None = None
logits_path: str | None = None
tinker_path: str | None = None
size: int | None = None
size_bytes: int | None = None
visibility: str | None = None
public: bool | None = None
expired_at: datetime | None = None
expires_at: datetime | None = None


class _BackendCheckpointsList(_SdkBaseModel):
checkpoints: list[_BackendCheckpoint] = []
cursor: object | None = None


def _to_checkpoint(b: _BackendCheckpoint) -> types.Checkpoint:
public = b.public if b.public is not None else (b.visibility == "public")
return types.Checkpoint(
checkpoint_id=b.checkpoint_id,
checkpoint_type=b.checkpoint_type or "training", # type: ignore[arg-type]
time=b.time, # type: ignore[arg-type]
tinker_path=(b.tinker_path or b.logits_path or ""),
size_bytes=(b.size_bytes if b.size_bytes is not None else b.size),
public=bool(public),
expires_at=(b.expires_at or b.expired_at),
)


def _to_list_response(raw: _BackendCheckpointsList) -> types.CheckpointsListResponse:
cursor = None
if isinstance(raw.cursor, dict):
try:
cursor = types.Cursor(**raw.cursor)
except Exception:
cursor = None
return types.CheckpointsListResponse(
checkpoints=[_to_checkpoint(b) for b in raw.checkpoints],
cursor=cursor,
)


class RestClient(TinkerRestClient):
"""Logits-compatible RestClient.

Expand Down Expand Up @@ -41,6 +103,111 @@ def _parse_checkpoint_path(tinker_path: str) -> tuple[str, str]:
checkpoint_id = parsed.checkpoint_id.rsplit("/", 1)[-1]
return parsed.training_run_id, checkpoint_id

# ------------------------------------------------------------------
# Checkpoint listing (backend returns `logits_path`, not `tinker_path`)
# ------------------------------------------------------------------
async def _fetch_run_checkpoints(self, training_run_id: types.ModelID) -> _BackendCheckpointsList:
async def _send() -> _BackendCheckpointsList:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.get(
f"/api/v1/training_runs/{training_run_id}/checkpoints",
cast_to=_BackendCheckpointsList,
)

return await self.holder.execute_with_retries(_send)

async def _fetch_user_checkpoints(self, limit: int, offset: int) -> _BackendCheckpointsList:
async def _send() -> _BackendCheckpointsList:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.get(
"/api/v1/checkpoints",
options={"params": {"limit": limit, "offset": offset}},
cast_to=_BackendCheckpointsList,
)

return await self.holder.execute_with_retries(_send)

def _list_checkpoints_submit(
self, training_run_id: types.ModelID
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
async def _coro() -> types.CheckpointsListResponse:
return _to_list_response(await self._fetch_run_checkpoints(training_run_id))

return self.holder.run_coroutine_threadsafe(_coro())

def _list_user_checkpoints_submit(
self, limit: int = 100, offset: int = 0
) -> AwaitableConcurrentFuture[types.CheckpointsListResponse]:
async def _coro() -> types.CheckpointsListResponse:
return _to_list_response(await self._fetch_user_checkpoints(limit, offset))

return self.holder.run_coroutine_threadsafe(_coro())

# ------------------------------------------------------------------
# Publish / unpublish (backend resolves these routes by checkpoint_id,
# but the path the backend hands back embeds the human-readable name)
# ------------------------------------------------------------------
async def _resolve_checkpoint_id(self, training_run_id: types.ModelID, segment: str) -> str:
"""Map a path segment (name *or* checkpoint_id) to the real checkpoint_id.

The publish/unpublish routes only accept the opaque ``chk_...`` id, but
the ``logits_path`` the backend advertises (and that `save_state`
returns) embeds the checkpoint's name. Resolve the name to its id via
the run's checkpoint list; pass through anything that already looks like
an id so we never add a round-trip we don't need.
"""
if segment.startswith("chk_"):
return segment
raw = await self._fetch_run_checkpoints(training_run_id)
for b in raw.checkpoints:
if b.checkpoint_id == segment:
return segment
for b in raw.checkpoints:
if b.name == segment:
return b.checkpoint_id
return segment

def _resolve_and_set_publish_submit(
self, training_run_id: types.ModelID, segment: str, *, publish: bool
) -> AwaitableConcurrentFuture[None]:
async def _coro() -> None:
checkpoint_id = await self._resolve_checkpoint_id(training_run_id, segment)
url = f"/api/v1/training_runs/{training_run_id}/checkpoints/{checkpoint_id}/publish"

async def _send() -> None:
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
if publish:
await client.post(url, cast_to=NoneType)
else:
await client.delete(url, cast_to=NoneType)

await self.holder.execute_with_retries(_send)

return self.holder.run_coroutine_threadsafe(_coro())

def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
training_run_id, segment = self._parse_checkpoint_path(tinker_path)
return self._resolve_and_set_publish_submit(
training_run_id, segment, publish=True
).future()

async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
training_run_id, segment = self._parse_checkpoint_path(tinker_path)
await self._resolve_and_set_publish_submit(training_run_id, segment, publish=True)

def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
training_run_id, segment = self._parse_checkpoint_path(tinker_path)
return self._resolve_and_set_publish_submit(
training_run_id, segment, publish=False
).future()

async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
training_run_id, segment = self._parse_checkpoint_path(tinker_path)
await self._resolve_and_set_publish_submit(training_run_id, segment, publish=False)

# ------------------------------------------------------------------
# Path helpers unchanged from the original Logits shim
# ------------------------------------------------------------------
def get_training_run_by_tinker_path(
self, tinker_path: str, access_scope: Literal["owned", "accessible"] = "owned"
) -> ConcurrentFuture[types.TrainingRun]:
Expand Down Expand Up @@ -73,22 +240,6 @@ async def get_checkpoint_archive_url_from_tinker_path_async(
training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path)
return await self._get_checkpoint_archive_url_submit(training_run_id, checkpoint_id)

def publish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path)
return self._publish_checkpoint_submit(training_run_id, checkpoint_id).future()

async def publish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path)
await self._publish_checkpoint_submit(training_run_id, checkpoint_id)

def unpublish_checkpoint_from_tinker_path(self, tinker_path: str) -> ConcurrentFuture[None]:
training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path)
return self._unpublish_checkpoint_submit(training_run_id, checkpoint_id).future()

async def unpublish_checkpoint_from_tinker_path_async(self, tinker_path: str) -> None:
training_run_id, checkpoint_id = self._parse_checkpoint_path(tinker_path)
await self._unpublish_checkpoint_submit(training_run_id, checkpoint_id)

def set_checkpoint_ttl_from_tinker_path(
self, tinker_path: str, ttl_seconds: int | None
) -> ConcurrentFuture[None]:
Expand Down
65 changes: 64 additions & 1 deletion tests/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@

import pytest

from logits._rest_client import RestClient
from logits._rest_client import (
RestClient,
_BackendCheckpoint,
_BackendCheckpointsList,
_to_checkpoint,
_to_list_response,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -31,3 +37,60 @@ def test_parse_checkpoint_path_strips_kind_prefix(
def test_parse_checkpoint_path_rejects_non_weights_uri() -> None:
with pytest.raises(ValueError):
RestClient._parse_checkpoint_path("https://example.com/model")


def test_backend_checkpoint_maps_renamed_fields() -> None:
# The Logits backend sends `logits_path` / `size` / `visibility` /
# `expired_at`; the upstream Checkpoint type requires `tinker_path` and uses
# `size_bytes` / `public` / `expires_at`. The raw list call would fail
# validation without this translation.
raw = _BackendCheckpoint.model_validate(
{
"checkpoint_id": "chk_abc",
"checkpoint_type": "training",
"name": "000010",
"time": "2026-06-09T23:02:26.201258-07:00",
"size": 810405661,
"visibility": "public",
"expired_at": None,
"logits_path": "logits://model_6320e11a/weights/000010",
}
)
ck = _to_checkpoint(raw)
assert ck.tinker_path == "logits://model_6320e11a/weights/000010"
assert ck.size_bytes == 810405661
assert ck.public is True
assert ck.expires_at is None
assert ck.checkpoint_type == "training"


def test_backend_checkpoint_private_visibility() -> None:
raw = _BackendCheckpoint.model_validate(
{
"checkpoint_id": "chk_def",
"checkpoint_type": "sampler",
"time": "2026-06-09T23:02:26.201258-07:00",
"visibility": "private",
"logits_path": "logits://model_x/sampler_weights/final",
}
)
assert _to_checkpoint(raw).public is False


def test_to_list_response_builds_upstream_type() -> None:
raw = _BackendCheckpointsList.model_validate(
{
"checkpoints": [
{
"checkpoint_id": "chk_1",
"checkpoint_type": "training",
"time": "2026-06-09T23:02:26.201258-07:00",
"logits_path": "logits://run/weights/a",
}
]
}
)
resp = _to_list_response(raw)
assert len(resp.checkpoints) == 1
assert resp.checkpoints[0].checkpoint_id == "chk_1"
assert resp.checkpoints[0].tinker_path == "logits://run/weights/a"
Loading