Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,71 +1,48 @@
import torch
from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
from PIL import Image
import io
import base64

import httpx
from flo_utils.utils.log import logger

from rag_ingestion.env import INFERENCE_SERVICE_URL
from rag_ingestion.models.knowledge_base_embeddings import KnowledgeBaseEmbeddingObject


class ImageEmbedding:
def __init__(self):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Initializing models on device: {self.device}')

# CLIP Model (Fixed: Added .to(self.device))
self.clip_model_name = 'openai/clip-vit-base-patch32'
self.model = (
CLIPModel.from_pretrained(self.clip_model_name).to(self.device).eval()
)
self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
"""Image embeddings via the inference service (CLIP + DINO)."""

# DINO Model (No change needed for device_map="auto")
self.dino_model_name = 'facebook/dinov3-vitl16-pretrain-lvd1689m'
self.dino_processor = AutoImageProcessor.from_pretrained(self.dino_model_name)
self.dino_model = AutoModel.from_pretrained(
self.dino_model_name, device_map='auto', trust_remote_code=True
).eval()
def __init__(self):
    """Resolve and cache the inference-service embeddings endpoint.

    Raises:
        ValueError: if INFERENCE_SERVICE_URL is not configured.
    """
    if INFERENCE_SERVICE_URL:
        service_root = INFERENCE_SERVICE_URL.rstrip('/')
    else:
        raise ValueError(
            'INFERENCE_SERVICE_URL must be set for image embedding API calls'
        )
    self._embed_url = f'{service_root}/inference/v1/query/embeddings'
    logger.info(f'Image embedding endpoint: {self._embed_url}')

def embed_image(self, file_content: bytes) -> KnowledgeBaseEmbeddingObject:
image = Image.open(io.BytesIO(file_content))
if image.mode != 'RGB':
image = image.convert('RGB')

# CLIP Inputs (Fixed: Added .to(self.device))
inputs = self.processor(images=image, return_tensors='pt').to(self.device)

# --- CLIP EMBEDDING ---
with torch.no_grad():
image_features = self.model.get_image_features(**inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
embedding = image_features.squeeze().cpu().numpy().tolist()

# --- DINO EMBEDDING CALL ---
dino_embedding = self.embed_image_dino(file_content)
payload = {'image_data': base64.b64encode(file_content).decode('ascii')}
response = httpx.post(
self._embed_url,
json=payload,
timeout=httpx.Timeout(120.0, connect=30.0),
)
response.raise_for_status()
Comment on lines +24 to +29
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Handle transient inference failures per item.

This remote call now sits inside the batch worker path, and wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/kb_storage_processor.py:87-102 re-raises worker exceptions via future.result(). With the larger streaming batches, a single timeout/429/5xx here will fail the whole batch. Please add bounded retry/backoff for transient failures, or convert this into a per-item failure the caller can continue past.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@wavefront/server/background_jobs/rag_ingestion/rag_ingestion/embeddings/image_embed.py`
around lines 24-29, the HTTP call in image_embed.py (the httpx.post to
self._embed_url followed by response.raise_for_status) must not raise and fail
the whole batch on transient errors; implement a bounded retry with exponential
backoff for transient errors (timeouts, 429, 5xx) around the httpx.post, and if
retries still fail convert the result into a per-item failure instead of raising
— e.g., return/append an error marker or partial result for that image so the
caller (kb_storage_processor using future.result()) can continue processing
other items. Ensure you only retry idempotent/transient conditions, cap retries
and total wait, and use the function wrapping the post (the embedding method
that contains self._embed_url and response.raise_for_status) to surface per-item
errors rather than throwing.

body = response.json()
embeddings = body.get('data', {}).get('response')
if not isinstance(embeddings, list) or len(embeddings) < 2:
raise ValueError(
f"Unexpected embedding response shape — expected list of at least 2 entries: {body!r}"
)

clip_entry, dino_entry = embeddings[0], embeddings[1]
if not isinstance(clip_entry, dict) or 'clip' not in clip_entry:
raise ValueError(f"Missing CLIP embedding in response entry: {clip_entry!r}")
if not isinstance(dino_entry, dict) or 'dino' not in dino_entry:
raise ValueError(f"Missing DINO embedding in response entry: {dino_entry!r}")

# Pass the DINO embedding to the correct field
return KnowledgeBaseEmbeddingObject(
embedding_vector=embedding,
embedding_vector_1=dino_embedding,
embedding_vector=clip_entry['clip'],
embedding_vector_1=dino_entry['dino'],
chunk_text='image data',
chunk_index='chunk_0',
)

@torch.inference_mode()
def embed_image_dino(self, file_content: bytes) -> list:
    """Return an L2-normalized DINO embedding for raw image bytes.

    Decodes the bytes with PIL, forces RGB mode, runs the DINO backbone,
    and uses the normalized first (CLS) hidden-state token as the image
    representation.
    """
    pil_image = Image.open(io.BytesIO(file_content))
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    batch = self.dino_processor(images=pil_image, return_tensors='pt')
    # The model may be sharded (device_map='auto'); follow its device.
    device = self.dino_model.device
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    hidden_states = self.dino_model(**batch).last_hidden_state
    cls_token = hidden_states[:, 0]
    normalized = cls_token / cls_token.norm(dim=-1, keepdim=True)
    return normalized.squeeze().cpu().numpy().tolist()
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
# Runtime configuration read from the environment at import time.
# os.getenv returns a str when the variable is set, so numeric settings are
# cast explicitly; previously RETRY_COUNT was a str when configured via the
# environment but an int (3) when defaulted, unlike STREAMING_BATCH_SIZE
# which was already cast.
RETRY_COUNT = int(os.getenv('RETRY_COUNT', 3))
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
EMBEDDING_SERVICE_URL = os.getenv('EMBEDDING_SERVICE_URL')
INFERENCE_SERVICE_URL = os.getenv('INFERENCE_SERVICE_URL')
FLOWARE_SERVICE_URL = os.getenv('FLOWARE_SERVICE_URL')
APP_ENV = os.getenv('APP_ENV', 'dev')
PASSTHROUGH_SECRET = os.getenv('PASSTHROUGH_SECRET')
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL', 'text-embedding-3-small')
STREAMING_BATCH_SIZE = int(os.getenv('STREAMING_BATCH_SIZE', 100))
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rag_ingestion.processors.kb_storage_processor import KbStorageProcessor
from db_repo_module.cache.cache_manager import CacheManager
from flo_cloud.kms import FloKmsService
from rag_ingestion.env import CLOUD_PROVIDER, RETRY_COUNT
from rag_ingestion.env import CLOUD_PROVIDER, RETRY_COUNT, STREAMING_BATCH_SIZE
from flo_cloud.cloud_storage import CloudStorageManager
from flo_cloud.message_queue import MessageQueueManager
import os
Expand Down Expand Up @@ -30,6 +30,7 @@ def main():

# Initialize stream listener
listener = RagStreamListener(
streaming_batch_size=STREAMING_BATCH_SIZE,
event_manager=event_manager,
processor=KbStorageProcessor(
storage_manager,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from flo_cloud.cloud_storage import CloudStorageManager
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List
from typing import List, Tuple
from flo_utils.utils.log import logger
from rag_ingestion.service.kb_rag_storage import KBRagStorage
from rag_ingestion.embeddings.embed import EmbeddingFunc
Expand Down Expand Up @@ -54,6 +55,20 @@ async def _extract_content(
)
return DocContent(content=content, document_type=document_type)

def __embed_single_insight(
    self, kb_insight: ProcessingResult[KbStorageInsights]
) -> Tuple[List[KnowledgeBaseEmbeddingObject], str, str, DocumentType]:
    """Embed a single insight's document content.

    Returns a tuple of (embedding objects, doc_id, kb_id, document type).
    Document types other than PDF, TEXT, and IMAGE yield an empty list.
    """
    insight = kb_insight.insights
    doc_type = insight.doc_content.document_type
    content = insight.doc_content.content

    if doc_type == DocumentType.IMAGE:
        embedded = [self.image_embedding.embed_image(content)]
    elif doc_type in (DocumentType.PDF, DocumentType.TEXT):
        embedded = self.kb_rag_storage.process_document([content])
    else:
        embedded = []
    return embedded, insight.doc_id, insight.kb_id, doc_type

def __insert_kb_from_message(
self, insights: List[ProcessingResult[KbStorageInsights]]
):
Expand All @@ -68,35 +83,23 @@ def __insert_kb_from_message(
None
"""
try:
embeddings: List[EmbeddingsToStore] = []
for kb_insight in insights:
kb_id = kb_insight.insights.kb_id
doc_id = kb_insight.insights.doc_id
document_type = kb_insight.insights.doc_content.document_type

logger.info('Embeddings storing process is started')
if (
document_type == DocumentType.PDF
or document_type == DocumentType.TEXT
):
extracted_docs = [kb_insight.insights.doc_content.content]
docs: List[KnowledgeBaseEmbeddingObject] = (
self.kb_rag_storage.process_document(extracted_docs)
)
elif document_type == DocumentType.IMAGE:
image_data = [kb_insight.insights.doc_content.content]
docs: List[KnowledgeBaseEmbeddingObject] = [
self.image_embedding.embed_image(image_data)
for image_data in image_data
]
embeddings.append(
EmbeddingsToStore(
kb_embeddings=docs,
doc_id=doc_id,
kb_id=kb_id,
file_type=document_type,
logger.info('Embeddings storing process is started')
with ThreadPoolExecutor(max_workers=10) as executor:
futures = {
executor.submit(self.__embed_single_insight, kb_insight): kb_insight
for kb_insight in insights
}
embeddings: List[EmbeddingsToStore] = []
for future in as_completed(futures):
docs, doc_id, kb_id, document_type = future.result()
embeddings.append(
EmbeddingsToStore(
kb_embeddings=docs,
doc_id=doc_id,
kb_id=kb_id,
file_type=document_type,
)
)
)

self.kb_rag_storage.upload_embedding_with_retry(embeddings=embeddings)
logger.info('Embeddings are stored in the db')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class MessageQueueDict:
class MessageQueue(ABC):
@abstractmethod
def receive_messages(
self, max_messages=10, wait_time_sec=20
self, max_messages: int = 10, wait_time_sec: int = 20
) -> List[MessageQueueDict] | None:
"""
Receive messages from the event queue.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ def __init__(
cache_manager: CacheManager,
retry_count: int,
streaming_batch_size: int = 5,
wait_time_sec: int = 20,
):
self.event_manager = event_manager
self.processor = processor
self.cache_manager = cache_manager
self.retry_count = retry_count
self.streaming_batch_size = streaming_batch_size
self.wait_time_sec = wait_time_sec

def handle_error(
self,
Expand Down Expand Up @@ -71,7 +73,8 @@ async def receive_queue_messages(self, worker_id: str):
while True:
try:
response = self.event_manager.receive_messages(
max_messages=self.streaming_batch_size
max_messages=self.streaming_batch_size,
wait_time_sec=self.wait_time_sec
)
messages: List[BaseEventMessage] = self.get_event_messages(response)
logger.info(f'{worker_id}: listening for messages...')
Expand Down
Loading