from typing import Any

import lz4.frame
import numpy as np
from httpx import USE_CLIENT_DEFAULT
from httpx._client import UseClientDefault
from httpx._types import TimeoutTypes
from numpy.typing import DTypeLike, NDArray
from PIL.Image import Image

from rationai._resource import APIResource, AsyncAPIResource


# NOTE(review): the synchronous class header is not part of the visible hunks;
# the name `Models` is assumed from its `AsyncModels` counterpart -- confirm.
class Models(APIResource):
    # Only `embed_image` is reproduced here; the other methods of this class
    # (e.g. `segment_image`) are only partially visible in the diff context.

    def embed_image(
        self,
        model: str,
        image: Image | NDArray[np.uint8],
        output_dtype: DTypeLike = np.float32,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> NDArray[np.floating[Any]]:
        """Request an embedding of ``image`` from ``model``.

        The raw bytes of the image (``image.tobytes()``) are LZ4-frame
        compressed and sent through ``self._post``; the JSON body of the
        response is then converted to a numpy array.

        Args:
            model: The name of the model to use for embedding.
            image: The image to embed. It must be a uint8 RGB image; no
                conversion or validation is performed here.
            output_dtype: Numpy dtype of the returned array
                (e.g. ``np.float16``, ``np.float32``).
            timeout: Optional timeout for the request.

        Returns:
            The JSON-decoded response as a numpy array of ``output_dtype``
            (expected to be a 1-D embedding vector).

        Raises:
            Propagates whatever ``raise_for_status()`` raises when the
            response carries an error status.
        """
        payload = lz4.frame.compress(image.tobytes())
        reply = self._post(model, data=payload, timeout=timeout)
        # Fail on an error status before attempting to parse the body.
        reply.raise_for_status()
        return np.array(reply.json(), dtype=output_dtype)


class AsyncModels(AsyncAPIResource):
    # Only `embed_image` is reproduced here; `classify_image` and
    # `segment_image` are only partially visible in the diff context.

    async def embed_image(
        self,
        model: str,
        image: Image | NDArray[np.uint8],
        output_dtype: DTypeLike = np.float32,
        timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT,
    ) -> NDArray[np.floating[Any]]:
        """Request an embedding of ``image`` from ``model`` (async variant).

        The raw bytes of the image (``image.tobytes()``) are LZ4-frame
        compressed and sent through the awaited ``self._post``; the JSON body
        of the response is then converted to a numpy array.

        Args:
            model: The name of the model to use for embedding.
            image: The image to embed. It must be a uint8 RGB image; no
                conversion or validation is performed here.
            output_dtype: Numpy dtype of the returned array
                (e.g. ``np.float16``, ``np.float32``).
            timeout: Optional timeout for the request.

        Returns:
            The JSON-decoded response as a numpy array of ``output_dtype``
            (expected to be a 1-D embedding vector).

        Raises:
            Propagates whatever ``raise_for_status()`` raises when the
            response carries an error status.
        """
        payload = lz4.frame.compress(image.tobytes())
        reply = await self._post(model, data=payload, timeout=timeout)
        # Fail on an error status before attempting to parse the body.
        reply.raise_for_status()
        return np.array(reply.json(), dtype=output_dtype)