# Review notes (leading text before the first `diff --git` is ignored by `git apply` / `patch`).
# * Line structure and indentation were reconstructed from a whitespace-collapsed
#   copy of this patch. Context-line whitespace and blank-line positions are a
#   best guess: apply with `git apply --ignore-whitespace` and verify against the tree.
# * `index` lines are carried over unchanged; post-image hashes are stale for the
#   files edited in review (image_embed.py, kb_storage_processor.py, stream_listner.py).
# * NOTE(review): STREAMING_BATCH_SIZE defaults to 100, while receive_messages
#   defaults to max_messages=10 and RagStreamListener to 5. Confirm the queue
#   backend accepts 100 per receive (SQS caps MaxNumberOfMessages at 10).
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/embeddings/image_embed.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/embeddings/image_embed.py
index fd869371..5411af45 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/embeddings/image_embed.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/embeddings/image_embed.py
@@ -1,71 +1,56 @@
-import torch
-from transformers import CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
-from PIL import Image
-import io
+import base64
+
+import httpx
+from flo_utils.utils.log import logger
+
+from rag_ingestion.env import INFERENCE_SERVICE_URL
 from rag_ingestion.models.knowledge_base_embeddings import KnowledgeBaseEmbeddingObject
 
 
 class ImageEmbedding:
-    def __init__(self):
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        print(f'Initializing models on device: {self.device}')
-
-        # CLIP Model (Fixed: Added .to(self.device))
-        self.clip_model_name = 'openai/clip-vit-base-patch32'
-        self.model = (
-            CLIPModel.from_pretrained(self.clip_model_name).to(self.device).eval()
-        )
-        self.processor = CLIPProcessor.from_pretrained(self.clip_model_name)
+    """Image embeddings via the inference service (CLIP + DINO)."""
 
-        # DINO Model (No change needed for device_map="auto")
-        self.dino_model_name = 'facebook/dinov3-vitl16-pretrain-lvd1689m'
-        self.dino_processor = AutoImageProcessor.from_pretrained(self.dino_model_name)
-        self.dino_model = AutoModel.from_pretrained(
-            self.dino_model_name, device_map='auto', trust_remote_code=True
-        ).eval()
+    def __init__(self):
+        if not INFERENCE_SERVICE_URL:
+            raise ValueError(
+                'INFERENCE_SERVICE_URL must be set for image embedding API calls'
+            )
+        base = INFERENCE_SERVICE_URL.rstrip('/')
+        self._embed_url = f'{base}/inference/v1/query/embeddings'
+        logger.info(f'Image embedding endpoint: {self._embed_url}')
 
     def embed_image(self, file_content: bytes) -> KnowledgeBaseEmbeddingObject:
-        image = Image.open(io.BytesIO(file_content))
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-
-        # CLIP Inputs (Fixed: Added .to(self.device))
-        inputs = self.processor(images=image, return_tensors='pt').to(self.device)
-
-        # --- CLIP EMBEDDING ---
-        with torch.no_grad():
-            image_features = self.model.get_image_features(**inputs)
-            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-            embedding = image_features.squeeze().cpu().numpy().tolist()
-
-        # --- DINO EMBEDDING CALL ---
-        dino_embedding = self.embed_image_dino(file_content)
+        """Return CLIP and DINO embeddings for raw image bytes.
+
+        Raises:
+            httpx.HTTPError: transport failure or non-2xx response.
+            ValueError: response body lacks the CLIP/DINO entries.
+        """
+        payload = {'image_data': base64.b64encode(file_content).decode('ascii')}
+        response = httpx.post(
+            self._embed_url,
+            json=payload,
+            timeout=httpx.Timeout(120.0, connect=30.0),
+        )
+        response.raise_for_status()
+        body = response.json()
+        # Guard so a null/non-object payload raises ValueError, not AttributeError.
+        data = body.get('data') if isinstance(body, dict) else None
+        embeddings = data.get('response') if isinstance(data, dict) else None
+        if not isinstance(embeddings, list) or len(embeddings) < 2:
+            raise ValueError(
+                f"Unexpected embedding response shape — expected list of at least 2 entries: {body!r}"
+            )
+
+        clip_entry, dino_entry = embeddings[0], embeddings[1]
+        if not isinstance(clip_entry, dict) or 'clip' not in clip_entry:
+            raise ValueError(f"Missing CLIP embedding in response entry: {clip_entry!r}")
+        if not isinstance(dino_entry, dict) or 'dino' not in dino_entry:
+            raise ValueError(f"Missing DINO embedding in response entry: {dino_entry!r}")
 
-        # Pass the DINO embedding to the correct field
         return KnowledgeBaseEmbeddingObject(
-            embedding_vector=embedding,
-            embedding_vector_1=dino_embedding,
+            embedding_vector=clip_entry['clip'],
+            embedding_vector_1=dino_entry['dino'],
             chunk_text='image data',
             chunk_index='chunk_0',
         )
-
-    @torch.inference_mode()
-    def embed_image_dino(self, file_content: bytes) -> list:
-        image = Image.open(io.BytesIO(file_content))
-        if image.mode != 'RGB':
-            image = image.convert('RGB')
-
-        inputs = self.dino_processor(images=image, return_tensors='pt')
-
-        target_device = self.dino_model.device
-        # Move inputs to the DINO model's device
-        inputs = {k: v.to(target_device) for k, v in inputs.items()}
-
-        outputs = self.dino_model(**inputs)
-
-        image_features = outputs.last_hidden_state[:, 0]
-
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        embedding = image_features.squeeze().cpu().numpy().tolist()
-
-        return embedding
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/env.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/env.py
index 6de4001d..e875f556 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/env.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/env.py
@@ -7,7 +7,9 @@
 RETRY_COUNT = os.getenv('RETRY_COUNT', 3)
 OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
 EMBEDDING_SERVICE_URL = os.getenv('EMBEDDING_SERVICE_URL')
+INFERENCE_SERVICE_URL = os.getenv('INFERENCE_SERVICE_URL')
 FLOWARE_SERVICE_URL = os.getenv('FLOWARE_SERVICE_URL')
 APP_ENV = os.getenv('APP_ENV', 'dev')
 PASSTHROUGH_SECRET = os.getenv('PASSTHROUGH_SECRET')
 EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL', 'text-embedding-3-small')
+STREAMING_BATCH_SIZE = int(os.getenv('STREAMING_BATCH_SIZE', 100))
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py
index ee0441dd..c09867b6 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/main.py
@@ -2,7 +2,7 @@
 from rag_ingestion.processors.kb_storage_processor import KbStorageProcessor
 from db_repo_module.cache.cache_manager import CacheManager
 from flo_cloud.kms import FloKmsService
-from rag_ingestion.env import CLOUD_PROVIDER, RETRY_COUNT
+from rag_ingestion.env import CLOUD_PROVIDER, RETRY_COUNT, STREAMING_BATCH_SIZE
 from flo_cloud.cloud_storage import CloudStorageManager
 from flo_cloud.message_queue import MessageQueueManager
 import os
@@ -30,6 +30,7 @@ def main():
 
     # Initialize stream listener
     listener = RagStreamListener(
+        streaming_batch_size=STREAMING_BATCH_SIZE,
         event_manager=event_manager,
         processor=KbStorageProcessor(
             storage_manager,
diff --git a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/kb_storage_processor.py b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/kb_storage_processor.py
index c6d7a70b..afc02b12 100644
--- a/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/kb_storage_processor.py
+++ b/wavefront/server/background_jobs/rag_ingestion/rag_ingestion/processors/kb_storage_processor.py
@@ -1,6 +1,7 @@
 from flo_cloud.cloud_storage import CloudStorageManager
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import List
+from typing import List, Tuple
 from flo_utils.utils.log import logger
 from rag_ingestion.service.kb_rag_storage import KBRagStorage
 from rag_ingestion.embeddings.embed import EmbeddingFunc
@@ -54,6 +55,25 @@ async def _extract_content(
         )
         return DocContent(content=content, document_type=document_type)
 
+    def __embed_single_insight(
+        self, kb_insight: ProcessingResult[KbStorageInsights]
+    ) -> Tuple[List[KnowledgeBaseEmbeddingObject], str, str, DocumentType]:
+        """Embed one insight and return (docs, doc_id, kb_id, document_type).
+
+        NOTE(review): runs concurrently on pool threads — confirm that
+        kb_rag_storage.process_document is thread-safe.
+        """
+        insight = kb_insight.insights
+        content = insight.doc_content.content
+        document_type = insight.doc_content.document_type
+        if document_type in (DocumentType.PDF, DocumentType.TEXT):
+            docs = self.kb_rag_storage.process_document([content])
+        elif document_type == DocumentType.IMAGE:
+            docs = [self.image_embedding.embed_image(content)]
+        else:
+            docs = []
+        return docs, insight.doc_id, insight.kb_id, document_type
+
     def __insert_kb_from_message(
         self, insights: List[ProcessingResult[KbStorageInsights]]
     ):
@@ -68,35 +88,24 @@ def __insert_kb_from_message(
             None
         """
         try:
-            embeddings: List[EmbeddingsToStore] = []
-            for kb_insight in insights:
-                kb_id = kb_insight.insights.kb_id
-                doc_id = kb_insight.insights.doc_id
-                document_type = kb_insight.insights.doc_content.document_type
-
-                logger.info('Embeddings storing process is started')
-                if (
-                    document_type == DocumentType.PDF
-                    or document_type == DocumentType.TEXT
-                ):
-                    extracted_docs = [kb_insight.insights.doc_content.content]
-                    docs: List[KnowledgeBaseEmbeddingObject] = (
-                        self.kb_rag_storage.process_document(extracted_docs)
-                    )
-                elif document_type == DocumentType.IMAGE:
-                    image_data = [kb_insight.insights.doc_content.content]
-                    docs: List[KnowledgeBaseEmbeddingObject] = [
-                        self.image_embedding.embed_image(image_data)
-                        for image_data in image_data
-                    ]
-                embeddings.append(
-                    EmbeddingsToStore(
-                        kb_embeddings=docs,
-                        doc_id=doc_id,
-                        kb_id=kb_id,
-                        file_type=document_type,
+            logger.info('Embeddings storing process is started')
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                futures = [
+                    executor.submit(self.__embed_single_insight, kb_insight)
+                    for kb_insight in insights
+                ]
+                embeddings: List[EmbeddingsToStore] = []
+                # Collect in submission order to keep the batch order deterministic.
+                for future in futures:
+                    docs, doc_id, kb_id, document_type = future.result()
+                    embeddings.append(
+                        EmbeddingsToStore(
+                            kb_embeddings=docs,
+                            doc_id=doc_id,
+                            kb_id=kb_id,
+                            file_type=document_type,
+                        )
                     )
-                )
 
             self.kb_rag_storage.upload_embedding_with_retry(embeddings=embeddings)
             logger.info('Embeddings are stored in the db')
diff --git a/wavefront/server/packages/flo_cloud/flo_cloud/_types/message_queue.py b/wavefront/server/packages/flo_cloud/flo_cloud/_types/message_queue.py
index a3a9f705..615ab190 100644
--- a/wavefront/server/packages/flo_cloud/flo_cloud/_types/message_queue.py
+++ b/wavefront/server/packages/flo_cloud/flo_cloud/_types/message_queue.py
@@ -13,7 +13,7 @@ class MessageQueueDict:
 class MessageQueue(ABC):
     @abstractmethod
     def receive_messages(
-        self, max_messages=10, wait_time_sec=20
+        self, max_messages: int = 10, wait_time_sec: int = 20
     ) -> List[MessageQueueDict] | None:
         """
         Receive messages from the event queue.
diff --git a/wavefront/server/packages/flo_utils/flo_utils/streaming/stream_listner.py b/wavefront/server/packages/flo_utils/flo_utils/streaming/stream_listner.py
index 767636b8..d1745ec3 100644
--- a/wavefront/server/packages/flo_utils/flo_utils/streaming/stream_listner.py
+++ b/wavefront/server/packages/flo_utils/flo_utils/streaming/stream_listner.py
@@ -19,12 +19,14 @@ def __init__(
         cache_manager: CacheManager,
         retry_count: int,
         streaming_batch_size: int = 5,
+        wait_time_sec: int = 20,
     ):
         self.event_manager = event_manager
         self.processor = processor
         self.cache_manager = cache_manager
         self.retry_count = retry_count
         self.streaming_batch_size = streaming_batch_size
+        self.wait_time_sec = wait_time_sec
 
     def handle_error(
         self,
@@ -71,7 +73,8 @@ async def receive_queue_messages(self, worker_id: str):
         while True:
             try:
                 response = self.event_manager.receive_messages(
-                    max_messages=self.streaming_batch_size
+                    max_messages=self.streaming_batch_size,
+                    wait_time_sec=self.wait_time_sec,
                 )
                 messages: List[BaseEventMessage] = self.get_event_messages(response)
                 logger.info(f'{worker_id}: listening for messages...')