Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion rationai/resources/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any

import lz4.frame
import numpy as np
from httpx import USE_CLIENT_DEFAULT
from httpx._client import UseClientDefault
from httpx._types import TimeoutTypes
from numpy.typing import NDArray
from numpy.typing import DTypeLike, NDArray
from PIL.Image import Image

from rationai._resource import APIResource, AsyncAPIResource
Expand Down Expand Up @@ -64,6 +66,30 @@ def segment_image(
lz4.frame.decompress(response.content), dtype=np.float16
).reshape(-1, h, w)

def embed_image(
self,
model: str,
image: Image | NDArray[np.uint8],
output_dtype: DTypeLike = np.float32,
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
) -> NDArray[np.floating[Any]]:
"""Compute an embedding vector for an image using the specified model.

Args:
model: The name of the model to use for embedding.
image: The image to embed. It must be uint8 RGB image.
output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
timeout: Optional timeout for the request.

Returns:
NDArray[np.floating[Any]]: The embedding vector as a 1-D numpy array.
"""
data = image.tobytes()
compressed_data = lz4.frame.compress(data)
response = self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()
return np.array(response.json(), dtype=output_dtype)


class AsyncModels(AsyncAPIResource):
async def classify_image(
Expand Down Expand Up @@ -119,3 +145,27 @@ async def segment_image(
return np.frombuffer(
lz4.frame.decompress(response.content), dtype=np.float16
).reshape(-1, h, w)

async def embed_image(
self,
model: str,
image: Image | NDArray[np.uint8],
output_dtype: DTypeLike = np.float32,
timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
) -> NDArray[np.floating[Any]]:
"""Compute an embedding vector for an image using the specified model.

Args:
model: The name of the model to use for embedding.
image: The image to embed. It must be uint8 RGB image.
output_dtype: Output numpy dtype for embeddings (e.g. np.float16, np.float32).
timeout: Optional timeout for the request.

Returns:
NDArray[np.floating[Any]]: The embedding vector as a 1-D numpy array.
"""
data = image.tobytes()
compressed_data = lz4.frame.compress(data)
response = await self._post(model, data=compressed_data, timeout=timeout)
response.raise_for_status()
return np.array(response.json(), dtype=output_dtype)
Loading