From c6e48b1d044e3a332404966775b020d4b5d1e49e Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Mon, 25 May 2026 13:03:09 -0500 Subject: [PATCH 1/2] Remove old submodule --- .gitmodules | 3 --- language/bert/DeepLearningExamples | 1 - 2 files changed, 4 deletions(-) delete mode 160000 language/bert/DeepLearningExamples diff --git a/.gitmodules b/.gitmodules index 184d1c83ad..ffed9043cd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,3 @@ -[submodule "language/bert/DeepLearningExamples"] - path = language/bert/DeepLearningExamples - url = https://github.com/NVIDIA/DeepLearningExamples.git [submodule "vision/medical_imaging/3d-unet-brats19/nnUnet"] path = vision/medical_imaging/3d-unet-brats19/nnUnet url = https://github.com/MIC-DKFZ/nnUNet.git diff --git a/language/bert/DeepLearningExamples b/language/bert/DeepLearningExamples deleted file mode 160000 index b03375bd6c..0000000000 --- a/language/bert/DeepLearningExamples +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b03375bd6c2c5233130e61a3be49e26d1a20ac7c From 41b7b5e1c5c0626046a8880e6280065f6ade2275 Mon Sep 17 00:00:00 2001 From: mlc-automations <3246381+mlc-automations@users.noreply.github.com> Date: Thu, 25 Jun 2026 15:37:09 +0000 Subject: [PATCH 2/2] [Automated Commit] Format Codebase --- e2e-rag/QSL.py | 11 +- e2e-rag/accuracy_eval.py | 76 ++- e2e-rag/datasetup_accuracy_eval.py | 100 ++-- e2e-rag/db_manifest.py | 26 +- e2e-rag/download_docs.py | 308 +++++----- e2e-rag/evaluate.py | 72 ++- e2e-rag/evaluation.py | 509 +++++++++------- e2e-rag/ingestion_monitor.py | 236 ++++---- e2e-rag/llm_logger.py | 46 +- e2e-rag/measure_indexing_with_chunking.py | 20 +- e2e-rag/multi_shot_retrieval.py | 669 +++++++++++++--------- e2e-rag/oracle_single_shot.py | 176 +++--- e2e-rag/params.py | 127 ++-- e2e-rag/perf_test_cache.py | 6 +- e2e-rag/read_docs.py | 297 +++++----- e2e-rag/reference_SUT.py | 15 +- e2e-rag/reference_SUT_datasetup.py | 69 ++- e2e-rag/reference_mlperf.py | 27 +- e2e-rag/reference_mlperf_datasetup.py | 20 +- e2e-rag/reranker_worker.py | 15 +- e2e-rag/retrieve/__init__.py | 2 +- e2e-rag/retrieve/filter.py | 149 ++--- e2e-rag/retrieve/ragdb.py | 91 +-- e2e-rag/retrieve/vectordb.py | 300 ++++++---- e2e-rag/single_shot_retrieval.py | 109 ++-- e2e-rag/text_splitter.py | 59 +- e2e-rag/utils.py | 88 +-- loadgen/issue_query_controller.cc | 12 +- loadgen/logging.cc | 3 +- 29 files changed, 2129 insertions(+), 1509 deletions(-) diff --git a/e2e-rag/QSL.py b/e2e-rag/QSL.py index 93e5f7933c..164c25c693 100644 --- a/e2e-rag/QSL.py +++ b/e2e-rag/QSL.py @@ -85,7 +85,8 @@ def __init__(self, dataset_path, perf_count=None, skip_qsl=False): print(f"Dataset loaded: {self.count} queries") if perf_count is not None: - print(f" (limited to first {perf_count} queries for performance testing)") + print( + f" (limited to first {perf_count} queries for performance testing)") def load_query_samples(self, sample_list): """ @@ -161,15 +162,12 @@ def __init__(self, dataset_path, perf_count=None): # limitations under the License. # ============================================================================= + """ Query Sample Library for RAG-QnA workload. Loads queries from frames_dataset.tsv and provides them to MLPerf Loadgen. """ -import os -import pandas as pd -import mlperf_loadgen as lg - class E2EQSL: """Query Sample Library for RAG-QnA multi-hop RAG benchmark.""" @@ -233,7 +231,8 @@ def __init__(self, dataset_path, perf_count=None, skip_qsl=False): print(f"Dataset loaded: {self.count} queries") if perf_count is not None: - print(f" (limited to first {perf_count} queries for performance testing)") + print( + f" (limited to first {perf_count} queries for performance testing)") def load_query_samples(self, sample_list): """ diff --git a/e2e-rag/accuracy_eval.py b/e2e-rag/accuracy_eval.py index 29ce199482..a4561c7449 100644 --- a/e2e-rag/accuracy_eval.py +++ b/e2e-rag/accuracy_eval.py @@ -34,9 +34,10 @@ # OpenRouter configuration DEFAULT_JUDGE_URL = "http://127.0.0.1:8123/v1/chat/completions" DEFAULT_JUDGE_MODEL = "gpt-oss-20b" -# Masked API key (set OPENROUTER_API_KEY environment variable to use OpenRouter) +# Masked API key (set OPENROUTER_API_KEY environment variable to use +# OpenRouter) OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', - 'sk-or-v1-****') + 'sk-or-v1-****') JUDGE_PROMPT = """You are an expert evaluator comparing LLM-generated answers to ground truth answers. @@ -83,7 +84,11 @@ def call_judge(question: str, ground_truth: str, llm_answer: str, } try: - response = requests.post(service_url, json=payload, headers=headers, timeout=60) + response = requests.post( + service_url, + json=payload, + headers=headers, + timeout=60) response.raise_for_status() result = response.json() @@ -105,7 +110,8 @@ def call_judge(question: str, ground_truth: str, llm_answer: str, return {"correct": False, "reasoning": f"Judge error: {e}"} -def calculate_retrieval_metrics(retrieved_urls: List[str], expected_urls: List[str]) -> Dict: +def calculate_retrieval_metrics( + retrieved_urls: List[str], expected_urls: List[str]) -> Dict: """Calculate precision, recall, F1 for retrieval.""" retrieved_set = set(retrieved_urls) @@ -118,7 +124,8 @@ def calculate_retrieval_metrics(retrieved_urls: List[str], expected_urls: List[s precision = len(correct) / len(retrieved_set) if retrieved_set else 0.0 recall = len(correct) / len(expected_set) if expected_set else 0.0 - f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0 + f1 = 2 * precision * recall / \ + (precision + recall) if (precision + recall) > 0 else 0.0 return { "precision": precision, @@ -128,8 +135,8 @@ def calculate_retrieval_metrics(retrieved_urls: List[str], expected_urls: List[s def evaluate_results(results: Dict, dataset_path: str, num_workers: int = 4, - judge_service_url: str = DEFAULT_JUDGE_URL, - judge_model: str = DEFAULT_JUDGE_MODEL) -> Dict: + judge_service_url: str = DEFAULT_JUDGE_URL, + judge_model: str = DEFAULT_JUDGE_MODEL) -> Dict: """ Evaluate loadgen results. @@ -190,12 +197,13 @@ def evaluate_single_query(query_id, result): expected_urls = gt_data['expected_urls'] # Calculate retrieval metrics - retrieval_metrics = calculate_retrieval_metrics(retrieved_urls, expected_urls) + retrieval_metrics = calculate_retrieval_metrics( + retrieved_urls, expected_urls) # Judge answer correctness judge_result = call_judge(query, ground_truth, llm_answer, - service_url=judge_service_url, - model_name=judge_model) + service_url=judge_service_url, + model_name=judge_model) answer_correct = judge_result.get('correct', False) return { @@ -229,7 +237,8 @@ def evaluate_single_query(query_id, result): total_queries += 1 if total_queries % 10 == 0: - print(f" Evaluated {total_queries}/{len(results)} queries...") + print( + f" Evaluated {total_queries}/{len(results)} queries...") except Exception as e: print(f"Error evaluating query: {e}") @@ -257,14 +266,37 @@ def evaluate_single_query(query_id, result): def main(): - parser = argparse.ArgumentParser(description="Evaluate RAG-QnA loadgen accuracy") - parser.add_argument('--log_dir', required=True, help='Loadgen log directory') - parser.add_argument('--results_file', required=True, help='SUT results JSON file') - parser.add_argument('--dataset_path', required=True, help='Path to frames_dataset.tsv') - parser.add_argument('--num_workers', type=int, default=4, help='Number of parallel judge workers') - parser.add_argument('--output', default='accuracy_results.json', help='Output file for detailed results') - parser.add_argument('--judge_service_url', default=DEFAULT_JUDGE_URL, help='Judge LLM service URL') - parser.add_argument('--judge_model', default=DEFAULT_JUDGE_MODEL, help='Judge LLM model name') + parser = argparse.ArgumentParser( + description="Evaluate RAG-QnA loadgen accuracy") + parser.add_argument( + '--log_dir', + required=True, + help='Loadgen log directory') + parser.add_argument( + '--results_file', + required=True, + help='SUT results JSON file') + parser.add_argument( + '--dataset_path', + required=True, + help='Path to frames_dataset.tsv') + parser.add_argument( + '--num_workers', + type=int, + default=4, + help='Number of parallel judge workers') + parser.add_argument( + '--output', + default='accuracy_results.json', + help='Output file for detailed results') + parser.add_argument( + '--judge_service_url', + default=DEFAULT_JUDGE_URL, + help='Judge LLM service URL') + parser.add_argument( + '--judge_model', + default=DEFAULT_JUDGE_MODEL, + help='Judge LLM model name') args = parser.parse_args() # Load results @@ -280,9 +312,9 @@ def main(): judge_model=args.judge_model) # Print summary - print("\n" + "="*80) + print("\n" + "=" * 80) print("ACCURACY EVALUATION RESULTS") - print("="*80) + print("=" * 80) print(f"Total Queries: {metrics['total_queries']}") print(f"\nRetrieval Metrics:") print(f" Precision@N: {metrics['retrieval_precision']:.3f}") @@ -290,7 +322,7 @@ def main(): print(f" F1@N: {metrics['retrieval_f1']:.3f}") print(f"\nAnswer Quality:") print(f" LLM Judge Accuracy: {metrics['answer_accuracy']:.3f}") - print("="*80 + "\n") + print("=" * 80 + "\n") # Save detailed results with open(args.output, 'w') as f: diff --git a/e2e-rag/datasetup_accuracy_eval.py b/e2e-rag/datasetup_accuracy_eval.py index 7df71d03f0..3f89041b46 100755 --- a/e2e-rag/datasetup_accuracy_eval.py +++ b/e2e-rag/datasetup_accuracy_eval.py @@ -68,7 +68,8 @@ def parse_accuracy_log(log_path): try: qsl_idx = entry['qsl_idx'] - # Handle data field - can be hex string or list of bytes + # Handle data field - can be hex string or list of + # bytes data_field = entry['data'] if isinstance(data_field, str): # Hex string - convert to bytes @@ -180,7 +181,8 @@ def validate_database(database_path, retriever_model): print(f" ✗ Vector count: Cannot access vector store") # Check 2: Docstore consistency - if hasattr(db, '_vector_store') and hasattr(db._vector_store, 'index_to_docstore_id'): + if hasattr(db, '_vector_store') and hasattr( + db._vector_store, 'index_to_docstore_id'): docstore_count = len(db._vector_store.index_to_docstore_id) check_passed = (docstore_count == vector_count) validation_results["checks"].append({ @@ -194,7 +196,8 @@ def validate_database(database_path, retriever_model): if check_passed: print(f" ✓ Docstore consistency: {docstore_count} documents") else: - print(f" ✗ Docstore consistency: {vector_count} vectors but {docstore_count} documents") + print( + f" ✗ Docstore consistency: {vector_count} vectors but {docstore_count} documents") else: validation_results["checks"].append({ "name": "docstore_consistency", @@ -206,7 +209,8 @@ def validate_database(database_path, retriever_model): print(f" ✗ Docstore consistency: Cannot access docstore") # Check 3: Index dimension - if hasattr(db, '_vector_store') and hasattr(db._vector_store.index, 'd'): + if hasattr(db, '_vector_store') and hasattr( + db._vector_store.index, 'd'): dimension = db._vector_store.index.d expected_dim = 768 # e5-base-v2 check_passed = (dimension == expected_dim) @@ -218,7 +222,8 @@ def validate_database(database_path, retriever_model): "expected": expected_dim }) # Don't fail on dimension mismatch, just warn - print(f" {'✓' if check_passed else '⚠'} Index dimension: {dimension} (expected {expected_dim})") + print( + f" {'✓' if check_passed else '⚠'} Index dimension: {dimension} (expected {expected_dim})") else: validation_results["checks"].append({ "name": "index_dimension", @@ -241,7 +246,8 @@ def validate_database(database_path, retriever_model): }) validation_results["passed"] &= check_passed if check_passed: - print(f" ✓ Sample retrieval: Retrieved {len(results)} results") + print( + f" ✓ Sample retrieval: Retrieved {len(results)} results") else: print(f" ✗ Sample retrieval: No results returned") except Exception as e: @@ -268,7 +274,8 @@ def validate_database(database_path, retriever_model): return validation_results -def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): +def evaluate_accuracy(log_dir, output_dir, database_path, + retriever_model=None): """ Evaluate accuracy of datasetup workload. @@ -281,9 +288,9 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): Returns: dict: Accuracy results """ - print("="*80) + print("=" * 80) print("RAG-DB Accuracy Evaluation") - print("="*80) + print("=" * 80) print(f"Started: {datetime.now().isoformat()}") print() @@ -320,18 +327,21 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): # MD5 hash (32 hex characters = 32 bytes when encoded) try: md5_str = data.decode('utf-8') - if len(md5_str) == 32 and all(c in '0123456789abcdef' for c in md5_str): + if len(md5_str) == 32 and all( + c in '0123456789abcdef' for c in md5_str): md5_response = md5_str md5_qsl_idx = qsl_idx success_count += 1 # MD5 response counts as success else: - print(f"Warning: Invalid MD5 format at qsl_idx {qsl_idx}: {md5_str}") + print( + f"Warning: Invalid MD5 format at qsl_idx {qsl_idx}: {md5_str}") failure_count += 1 except UnicodeDecodeError: print(f"Warning: Cannot decode response at qsl_idx {qsl_idx}") failure_count += 1 else: - print(f"Warning: Unexpected response length {len(data)} at qsl_idx {qsl_idx}") + print( + f"Warning: Unexpected response length {len(data)} at qsl_idx {qsl_idx}") failure_count += 1 print(f"Response Summary:") @@ -372,13 +382,14 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): # Compare MD5s md5_match = (md5_response == actual_md5) - print("="*80) + print("=" * 80) print("Accuracy Results:") - print("="*80) + print("=" * 80) print(f" Total files processed: {success_count + failure_count}") print(f" Successful: {success_count}") print(f" Failed: {failure_count}") - print(f" Success rate: {100.0 * success_count / (success_count + failure_count):.2f}%") + print( + f" Success rate: {100.0 * success_count / (success_count + failure_count):.2f}%") print() print(f" MD5 returned: {md5_response}") print(f" MD5 actual: {actual_md5}") @@ -390,12 +401,13 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): # 1. At least 99% of files succeeded # 2. MD5 matches min_success_rate = 0.99 - actual_success_rate = success_count / (success_count + failure_count) if (success_count + failure_count) > 0 else 0 + actual_success_rate = success_count / \ + (success_count + failure_count) if (success_count + failure_count) > 0 else 0 passed = (actual_success_rate >= min_success_rate) and md5_match print(f"Overall: {'✅ PASSED' if passed else '❌ FAILED'}") - print("="*80) + print("=" * 80) print() # Save results @@ -423,9 +435,9 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): # Run database validation if retriever model provided validation_results = None if retriever_model and os.path.exists(database_path): - print("="*80) + print("=" * 80) print("Database Validation") - print("="*80) + print("=" * 80) print() validation_results = validate_database(database_path, retriever_model) @@ -436,48 +448,54 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): overall_passed = passed and validation_results["passed"] accuracy_results["passed"] = overall_passed - print("="*80) + print("=" * 80) print("Validation Summary:") - print("="*80) + print("=" * 80) for check in validation_results["checks"]: - status_symbol = "✓" if check["result"] == "PASS" else ("⚠" if check["result"] == "WARN" else "✗") - print(f" {status_symbol} {check['description']}: {check['result']}") + status_symbol = "✓" if check["result"] == "PASS" else ( + "⚠" if check["result"] == "WARN" else "✗") + print( + f" {status_symbol} {check['description']}: {check['result']}") print() - print(f"Validation: {'✅ PASSED' if validation_results['passed'] else '❌ FAILED'}") + print( + f"Validation: {'✅ PASSED' if validation_results['passed'] else '❌ FAILED'}") print(f"Overall: {'✅ PASSED' if overall_passed else '❌ FAILED'}") - print("="*80) + print("=" * 80) print() # Write accuracy.txt in MLPerf format accuracy_txt_path = os.path.join(log_dir, "accuracy.txt") with open(accuracy_txt_path, 'w') as f: - f.write("="*80 + "\n") + f.write("=" * 80 + "\n") f.write("RAG-DB Accuracy Report\n") - f.write("="*80 + "\n") + f.write("=" * 80 + "\n") f.write(f"Timestamp: {datetime.now().isoformat()}\n") f.write(f"Database: {database_path}\n") f.write("\n") f.write("File Processing Results:\n") - f.write("-"*80 + "\n") + f.write("-" * 80 + "\n") f.write(f"Total files: {accuracy_results['total_files']}\n") f.write(f"Successful: {accuracy_results['successful_files']}\n") f.write(f"Failed: {accuracy_results['failed_files']}\n") f.write(f"Success rate: {accuracy_results['success_rate']*100:.2f}%\n") - f.write(f"Required rate: {accuracy_results['min_success_rate_required']*100:.2f}%\n") - f.write(f"Status: {'PASS' if actual_success_rate >= min_success_rate else 'FAIL'}\n") + f.write( + f"Required rate: {accuracy_results['min_success_rate_required']*100:.2f}%\n") + f.write( + f"Status: {'PASS' if actual_success_rate >= min_success_rate else 'FAIL'}\n") f.write("\n") f.write("MD5 Verification:\n") - f.write("-"*80 + "\n") + f.write("-" * 80 + "\n") f.write(f"MD5 returned by SUT: {accuracy_results['md5_response']}\n") f.write(f"MD5 actual (computed): {accuracy_results['md5_actual']}\n") - f.write(f"MD5 match: {'PASS' if accuracy_results['md5_match'] else 'FAIL'}\n") + f.write( + f"MD5 match: {'PASS' if accuracy_results['md5_match'] else 'FAIL'}\n") f.write("\n") if validation_results: f.write("Database Validation:\n") - f.write("-"*80 + "\n") + f.write("-" * 80 + "\n") for check in validation_results["checks"]: f.write(f" {check['name']}: {check['result']}\n") f.write(f" - {check['description']}\n") @@ -485,12 +503,14 @@ def evaluate_accuracy(log_dir, output_dir, database_path, retriever_model=None): f.write(f" - Value: {check['value']}\n") if 'error' in check: f.write(f" - Error: {check['error']}\n") - f.write(f"Validation status: {'PASS' if validation_results['passed'] else 'FAIL'}\n") + f.write( + f"Validation status: {'PASS' if validation_results['passed'] else 'FAIL'}\n") f.write("\n") - f.write("="*80 + "\n") - f.write(f"Overall Result: {'PASS' if accuracy_results['passed'] else 'FAIL'}\n") - f.write("="*80 + "\n") + f.write("=" * 80 + "\n") + f.write( + f"Overall Result: {'PASS' if accuracy_results['passed'] else 'FAIL'}\n") + f.write("=" * 80 + "\n") print(f"Accuracy report saved to: {accuracy_txt_path}") print() @@ -526,7 +546,11 @@ def main(): args = parser.parse_args() - results = evaluate_accuracy(args.log_dir, args.output_dir, args.database, args.retriever_model) + results = evaluate_accuracy( + args.log_dir, + args.output_dir, + args.database, + args.retriever_model) # Exit with appropriate code if results.get("passed", False): diff --git a/e2e-rag/db_manifest.py b/e2e-rag/db_manifest.py index 1cfe20f85f..39f8cbbc0a 100644 --- a/e2e-rag/db_manifest.py +++ b/e2e-rag/db_manifest.py @@ -112,7 +112,8 @@ def _gather_top_k(db: VectorDB, queries: List[Dict], k: int) -> List[Dict]: urls = [] for doc in results: md = getattr(doc, "metadata", None) or {} - url = md.get("original_url") or md.get("source") or md.get("base_filename") or "" + url = md.get("original_url") or md.get( + "source") or md.get("base_filename") or "" urls.append(url) out.append({"index": q["index"], "top_k_urls": urls}) return out @@ -139,10 +140,12 @@ def cmd_write(args): db = _load_db(args.db, args.retriever_model) total_passages = len(db._vector_store.index_to_docstore_id) - print(f"[manifest] DB has {total_passages} passages, dim={db._embedding_dimension}") + print( + f"[manifest] DB has {total_passages} passages, dim={db._embedding_dimension}") corpus_sha = _sha256_docstore(db) - sample_block = _gather_sample_embeddings(db, total_passages, NUM_SAMPLE_EMBEDDINGS) + sample_block = _gather_sample_embeddings( + db, total_passages, NUM_SAMPLE_EMBEDDINGS) probe_queries = _load_probe_queries(args.dataset, NUM_PROBE_QUERIES) probe_block = _gather_top_k(db, probe_queries, PROBE_TOP_K) @@ -237,7 +240,9 @@ def cmd_verify(args): f"top-{args.top_k_depth} {len(probe_queries) - len(rank_failures)}/" f"{len(probe_queries)} match") if rank_failures: - failures.append("probe-query top-K rank mismatch:\n" + "\n".join(rank_failures)) + failures.append( + "probe-query top-K rank mismatch:\n" + + "\n".join(rank_failures)) if failures: print("\n[verify] FAILED:") @@ -252,17 +257,24 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter) sub = parser.add_subparsers(dest="cmd", required=True) - pw = sub.add_parser("write", help="Generate a reference manifest from a DB.") + pw = sub.add_parser( + "write", + help="Generate a reference manifest from a DB.") pw.add_argument("--db", required=True) pw.add_argument("--retriever_model", default="intfloat/e5-base-v2") pw.add_argument("--dataset", default="data/frames_dataset.tsv") pw.add_argument("--output", required=True) pw.set_defaults(func=cmd_write) - pv = sub.add_parser("verify", help="Verify a DB against a reference manifest.") + pv = sub.add_parser( + "verify", + help="Verify a DB against a reference manifest.") pv.add_argument("--db", required=True) pv.add_argument("--manifest", required=True) - pv.add_argument("--cosine-threshold", type=float, default=DEFAULT_COSINE_THRESHOLD) + pv.add_argument( + "--cosine-threshold", + type=float, + default=DEFAULT_COSINE_THRESHOLD) pv.add_argument("--top-k-depth", type=int, default=DEFAULT_TOP_K_DEPTH) pv.set_defaults(func=cmd_verify) diff --git a/e2e-rag/download_docs.py b/e2e-rag/download_docs.py index 1778e97e21..f5d1ad46e8 100644 --- a/e2e-rag/download_docs.py +++ b/e2e-rag/download_docs.py @@ -78,56 +78,61 @@ def fix_malformed_url(url: str) -> str: return url - - - class BaseDownloader(ABC): """Base class for downloading web pages in different formats.""" - + def __init__(self, output_dir: str, processes: int = 10): self.output_dir = Path(output_dir) self.output_dir.mkdir(exist_ok=True) self.processes = processes self.url_mapping = {} - + @abstractmethod def get_file_extension(self) -> str: """Return the file extension for this downloader.""" pass - + def create_filename(self, url: str) -> str: """Generate a filename from URL.""" - filename = url.replace("https://", "").replace("/", "_").replace(":", "_") - + filename = url.replace( + "https://", + "").replace( + "/", + "_").replace( + ":", + "_") + # Truncate if too long max_length = 200 if len(filename) > max_length: filename = filename[:max_length] - + return filename + self.get_file_extension() - + @abstractmethod - def download_single_url(self, url: str, output_path: Path) -> Tuple[bool, str]: + def download_single_url( + self, url: str, output_path: Path) -> Tuple[bool, str]: """ Download a single URL to the specified path. - + Returns: Tuple of (success: bool, error_message: str) """ pass - - def process_url(self, args_tuple: Tuple[str, Path, int, int]) -> Tuple[bool, str, str, str]: + + def process_url( + self, args_tuple: Tuple[str, Path, int, int]) -> Tuple[bool, str, str, str]: """Process a single URL - designed for multiprocessing.""" url, output_dir, index, total = args_tuple - + # Create safe filename filename = self.create_filename(url) output_path = output_dir / filename - + # Skip if file already exists if output_path.exists(): return True, filename, "Skipping", url - + # Download the URL try: success, error_msg = self.download_single_url(url, output_path) @@ -141,76 +146,82 @@ def process_url(self, args_tuple: Tuple[str, Path, int, int]) -> Tuple[bool, str return False, filename, error_msg, url else: return False, filename, error_msg, url - + except Exception as e: return False, filename, f"Exception: {str(e)[:100]}", url - - def download_urls(self, urls: List[str], retry_failures: bool = True) -> Dict[str, Any]: + + def download_urls( + self, urls: List[str], retry_failures: bool = True) -> Dict[str, Any]: """Download multiple URLs with parallel processing and progress tracking.""" - + if not urls: print("No URLs found to process") return {"successful": 0, "failed": 0, "failed_urls": []} - - print(f"Processing {len(urls)} URLs with {self.processes} parallel processes...") - + + print( + f"Processing {len(urls)} URLs with {self.processes} parallel processes...") + # Create progress bar progress_bar = tqdm( total=len(urls), desc="Starting downloads...", unit="URL" ) - + # Process URLs in parallel with progress bar start_time = time.time() - + # Prepare arguments for multiprocessing process_args = [(url, self.output_dir, i + 1, len(urls)) - for i, url in enumerate(urls)] - + for i, url in enumerate(urls)] + # Process with progress bar updates results = [] failed_urls = [] # Track failed URLs for detailed reporting - + with Pool(processes=self.processes) as pool: for result in pool.imap(self.process_url, process_args): success, filename, status, url = result results.append((success, filename)) - + base_filename = get_base_filename(filename) self.url_mapping[base_filename] = url - + # Update progress bar with status if status == "Skipping": - progress_bar.set_description(f"Skipping: {filename[:30]}...") + progress_bar.set_description( + f"Skipping: {filename[:30]}...") elif status == "Success": - progress_bar.set_description(f"✓ Success: {filename[:30]}...") + progress_bar.set_description( + f"✓ Success: {filename[:30]}...") else: # This is a failure case - progress_bar.set_description(f"✗ {status}: {filename[:30]}...") + progress_bar.set_description( + f"✗ {status}: {filename[:30]}...") failed_urls.append((filename, status, url)) print(f"\n❌ FAILED: {filename}") print(f" URL: {url}") print(f" Error: {status}") - + progress_bar.update(1) - + progress_bar.close() - + # Save URL mapping to JSON file self.save_url_mapping() - + # Count results successful = sum(1 for success, _ in results if success) failed = len(results) - successful - + end_time = time.time() duration = end_time - start_time - - print(f"\nDownload complete! Successful: {successful}, Failed: {failed}") + + print( + f"\nDownload complete! Successful: {successful}, Failed: {failed}") print(f"Total time: {duration:.2f} seconds") print(f"Average time per URL: {duration/len(urls):.2f} seconds") - + # Print detailed failure report if there were failures if failed_urls: print(f"\n=== FAILED DOWNLOADS DETAILS ===") @@ -219,7 +230,7 @@ def download_urls(self, urls: List[str], retry_failures: bool = True) -> Dict[st print(f" URL: {url}") print(f" Error: {status}") print() - + # Retry failed URLs if requested if retry_failures: retry_result = self.retry_failed_urls(failed_urls) @@ -227,23 +238,24 @@ def download_urls(self, urls: List[str], retry_failures: bool = True) -> Dict[st failed = retry_result["still_failed"] else: print(f"\n✅ All downloads completed successfully!") - + return { "successful": successful, "failed": failed, "failed_urls": failed_urls, "duration": duration } - + def save_url_mapping(self): """Save URL mapping to JSON file.""" save_url_mapping(str(self.output_dir), self.url_mapping) - - def retry_failed_urls(self, failed_urls: List[Tuple[str, str, str]]) -> Dict[str, int]: + + def retry_failed_urls( + self, failed_urls: List[Tuple[str, str, str]]) -> Dict[str, int]: """Retry downloading failed URLs, but skip certain types of permanent failures.""" if not failed_urls: return {"successful": 0, "still_failed": 0} - + # Filter out failures that shouldn't be retried (permanent failures). # Note: We intentionally do NOT include "HTTP error 4" here because that # would also match retryable 429 (Too Many Requests) responses. @@ -252,62 +264,72 @@ def retry_failed_urls(self, failed_urls: List[Tuple[str, str, str]]) -> Dict[str "HTTP error 410", "HTTP error 451", "Invalid or empty content", ] - + retryable_urls = [] permanent_failures = [] - + for filename, status, url in failed_urls: - is_permanent = any(keyword in status for keyword in permanent_failure_keywords) + is_permanent = any( + keyword in status for keyword in permanent_failure_keywords) if is_permanent: permanent_failures.append((filename, status, url)) else: retryable_urls.append((filename, status, url)) - + if permanent_failures: - print(f"\nSkipping {len(permanent_failures)} permanent failures (404s, etc.)") - + print( + f"\nSkipping {len(permanent_failures)} permanent failures (404s, etc.)") + if not retryable_urls: print("No retryable URLs found.") return {"successful": 0, "still_failed": len(permanent_failures)} - + print(f"\n=== RETRYING FAILED DOWNLOADS ===") - print(f"Retrying {len(retryable_urls)} failed URLs (skipping {len(permanent_failures)} permanent failures)...") - + print( + f"Retrying {len(retryable_urls)} failed URLs (skipping {len(permanent_failures)} permanent failures)...") + # Prepare arguments for retry retry_args = [(url, self.output_dir, i + 1, len(retryable_urls)) - for i, (_, _, url) in enumerate(retryable_urls)] - + for i, (_, _, url) in enumerate(retryable_urls)] + # Create progress bar for retry - progress_bar = tqdm(total=len(retryable_urls), desc="Retrying...", unit="URL") - + progress_bar = tqdm( + total=len(retryable_urls), + desc="Retrying...", + unit="URL") + successful_retries = 0 still_failed = [] - + with Pool(processes=self.processes) as pool: for result in pool.imap(self.process_url, retry_args): success, filename, status, url = result - + if success: - progress_bar.set_description(f"✓ Retry Success: {filename[:30]}...") + progress_bar.set_description( + f"✓ Retry Success: {filename[:30]}...") successful_retries += 1 else: - progress_bar.set_description(f"✗ Retry Failed: {filename[:30]}...") + progress_bar.set_description( + f"✗ Retry Failed: {filename[:30]}...") still_failed.append((filename, status, url)) print(f"\n❌ RETRY FAILED: {filename}") print(f" URL: {url}") print(f" Error: {status}") - + progress_bar.update(1) - + progress_bar.close() - + # Combine still failed with permanent failures all_failed = still_failed + permanent_failures - - print(f"\nRetry complete! Successfully retried: {successful_retries}, Still failed: {len(all_failed)}") - print(f" - Retryable failures: {len(still_failed)}") - print(f" - Permanent failures (404s, etc.): {len(permanent_failures)}") - + + print( + f"\nRetry complete! Successfully retried: {successful_retries}, Still failed: {len(all_failed)}") + print(f" - Retryable failures: {len(still_failed)}") + print( + f" - Permanent failures (404s, etc.): {len(permanent_failures)}") + if all_failed: print(f"\n=== STILL FAILED AFTER RETRY ===") for i, (filename, status, url) in enumerate(all_failed, 1): @@ -315,18 +337,19 @@ def retry_failed_urls(self, failed_urls: List[Tuple[str, str, str]]) -> Dict[str print(f" URL: {url}") print(f" Error: {status}") print() - - return {"successful": successful_retries, "still_failed": len(all_failed)} + + return {"successful": successful_retries, + "still_failed": len(all_failed)} class PDFDownloader(BaseDownloader): """Download web pages as PDFs using wkhtmltopdf.""" - + def get_file_extension(self) -> str: return ".pdf" - - - def download_single_url(self, url: str, output_path: Path) -> Tuple[bool, str]: + + def download_single_url( + self, url: str, output_path: Path) -> Tuple[bool, str]: """Download a single URL as PDF using wkhtmltopdf with try-fix-retry approach.""" def attempt_pdf_download(target_url: str) -> Tuple[bool, str, bool]: """ @@ -339,7 +362,7 @@ def attempt_pdf_download(target_url: str) -> Tuple[bool, str, bool]: f'--load-error-handling ignore --load-media-error-handling ignore ' f'--javascript-delay 2000 "{target_url}" "{output_path}"' ) - + try: result = subprocess.run( command, @@ -348,7 +371,7 @@ def attempt_pdf_download(target_url: str) -> Tuple[bool, str, bool]: text=True, timeout=120 ) - + if result.returncode == 0: return True, "Success", False else: @@ -357,20 +380,20 @@ def attempt_pdf_download(target_url: str) -> Tuple[bool, str, bool]: stderr_text = result.stderr[:200] error_msg += f" - {stderr_text}" # Check for 404-like errors in stderr - is_404 = any(phrase in stderr_text.lower() for phrase in - ['404', 'not found', 'page not found', 'http error']) + is_404 = any(phrase in stderr_text.lower() for phrase in + ['404', 'not found', 'page not found', 'http error']) return False, error_msg, is_404 return False, error_msg, False - + except subprocess.TimeoutExpired: return False, "Timeout (120s)", False - + # Try original URL first success, error_msg, is_404 = attempt_pdf_download(url) - + if success: return True, "Success" - + # If it was a 404-like error, try to fix the URL and retry if is_404: fixed_url = fix_malformed_url(url) @@ -381,30 +404,32 @@ def attempt_pdf_download(target_url: str) -> Tuple[bool, str, bool]: return True, f"Success (fixed URL)" else: return False, f"Original: {error_msg}; Fixed attempt: {retry_error_msg}" - + # Return original error if no fix was attempted or fix failed return False, error_msg class HTMLDownloader(BaseDownloader): """Download web pages as HTML files using requests.""" - + # Per-process retry policy for transient errors (429, 5xx) MAX_RETRIES = 5 BASE_BACKOFF = 2.0 # seconds; exponential: BASE_BACKOFF * 2**attempt MAX_BACKOFF = 60.0 # cap on a single sleep - def __init__(self, output_dir: str, processes: int = 4, delay: float = 1.0, timeout: int = 30): + def __init__(self, output_dir: str, processes: int = 4, + delay: float = 1.0, timeout: int = 30): # HTML downloading can use parallel processes with rate limiting super().__init__(output_dir, processes=processes) self.delay = delay self.timeout = timeout - + if requests is None: - raise ImportError("requests package is required for HTML downloads. Install with: pip install requests") - + raise ImportError( + "requests package is required for HTML downloads. Install with: pip install requests") + self.session = requests.Session() - + # Wikipedia's User-Agent policy asks for a descriptive UA that identifies # the tool/operator and includes a contact URL or email. Generic browser # UAs are aggressively rate-limited. Operators may override via the @@ -420,12 +445,12 @@ def __init__(self, output_dir: str, processes: int = 4, delay: float = 1.0, time 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 'Accept-Language': 'en-US,en;q=0.9', }) - + def get_file_extension(self) -> str: return ".html" - - - def _compute_backoff(self, attempt: int, retry_after_header: Optional[str]) -> float: + + def _compute_backoff(self, attempt: int, + retry_after_header: Optional[str]) -> float: """Compute backoff delay honoring Retry-After if present, with jitter.""" if retry_after_header: try: @@ -437,7 +462,8 @@ def _compute_backoff(self, attempt: int, retry_after_header: Optional[str]) -> f cap = min(self.BASE_BACKOFF * (2 ** attempt), self.MAX_BACKOFF) return random.uniform(self.BASE_BACKOFF, cap) - def download_single_url(self, url: str, output_path: Path) -> Tuple[bool, str]: + def download_single_url( + self, url: str, output_path: Path) -> Tuple[bool, str]: """Download a single URL as HTML using requests with try-fix-retry approach.""" def attempt_download(target_url: str) -> Tuple[bool, str, bool]: """ @@ -447,7 +473,8 @@ def attempt_download(target_url: str) -> Tuple[bool, str, bool]: last_error = "Unknown error" for attempt in range(self.MAX_RETRIES): try: - response = self.session.get(target_url, timeout=self.timeout) + response = self.session.get( + target_url, timeout=self.timeout) # Retry on 429 (rate limit) and 5xx (server errors) if response.status_code == 429 or 500 <= response.status_code < 600: @@ -472,14 +499,16 @@ def attempt_download(target_url: str) -> Tuple[bool, str, bool]: # Check if we got redirected to a different article if response.url != target_url: - print(f"Redirected from {target_url} to {response.url}") + print( + f"Redirected from {target_url} to {response.url}") if 'charset' in response.headers.get('content-type', ''): encoding = response.encoding else: encoding = 'utf-8' - html_content = response.content.decode(encoding, errors='ignore') + html_content = response.content.decode( + encoding, errors='ignore') if 'wikipedia' not in html_content.lower() or len(html_content) < 1000: return False, f"Invalid or empty content (length: {len(html_content)})", False @@ -514,13 +543,13 @@ def attempt_download(target_url: str) -> Tuple[bool, str, bool]: return False, f"Exception: {str(e)[:100]}", False return False, last_error, False - + # Try original URL first success, error_msg, is_404 = attempt_download(url) - + if success: return True, "Success" - + # If it was a 404-like error, try to fix the URL and retry if is_404: fixed_url = fix_malformed_url(url) @@ -531,21 +560,15 @@ def attempt_download(target_url: str) -> Tuple[bool, str, bool]: return True, f"Success (fixed URL)" else: return False, f"Original: {error_msg}; Fixed attempt: {retry_error_msg}" - + # Return original error if no fix was attempted or fix failed return False, error_msg - + def download_urls(self, urls: List[str], retry_failures: bool = True): """Download URLs with parallel processing and rate limiting.""" return super().download_urls(urls, retry_failures) - - - - - - def download_frames_dataset(output_dir): """Download the FRAMES dataset from Hugging Face and save as TSV.""" print("Downloading FRAMES dataset from Hugging Face...") @@ -623,7 +646,8 @@ def extract_wikipedia_links(item): return links -def extract_urls_from_frames_dataset(tsv_path: str, max_urls: Optional[int] = None) -> List[str]: +def extract_urls_from_frames_dataset( + tsv_path: str, max_urls: Optional[int] = None) -> List[str]: """Extract unique Wikipedia URLs from FRAMES dataset TSV file.""" # Load dataset df = pd.read_csv(tsv_path, sep='\t') @@ -635,15 +659,12 @@ def extract_urls_from_frames_dataset(tsv_path: str, max_urls: Optional[int] = No urls.add(link) urls = list(urls) - + # Limit number if specified if max_urls and max_urls > 0: urls = urls[:max_urls] - - return urls - - + return urls def main(): @@ -651,15 +672,15 @@ def main(): description='Download Wikipedia pages as PDFs or HTML files from FRAMES dataset or other sources.\nBy default, URLs are validated before downloading to avoid 404 errors.', formatter_class=argparse.RawTextHelpFormatter ) - + # Format selection parser.add_argument( - '--format', - choices=['pdf', 'html'], + '--format', + choices=['pdf', 'html'], default='pdf', help='Output format: pdf or html (default: pdf)' ) - + # URL sources (mutually exclusive) url_group = parser.add_mutually_exclusive_group() url_group.add_argument( @@ -668,15 +689,15 @@ def main(): ) url_group.add_argument( - '--urls', - nargs='+', + '--urls', + nargs='+', help='List of URLs to download' ) url_group.add_argument( - '--url-file', + '--url-file', help='File containing URLs (one per line)' ) - + # Output options parser.add_argument( '--output-dir', @@ -687,7 +708,7 @@ def main(): default='frames-benchmark-dataset', help='Directory for dataset files (default: frames-benchmark-dataset)' ) - + # Processing options parser.add_argument( '--max-files', @@ -702,7 +723,7 @@ def main(): 'Wikipedia rate-limits aggressive HTML scrapers; ' 'consider --processes 4 for HTML.' ) - + # HTML-specific options parser.add_argument( '--delay', @@ -717,7 +738,7 @@ def main(): default=30, help='Timeout for HTML requests in seconds (default: 30)' ) - + # Dataset options parser.add_argument( '--download-dataset', @@ -733,7 +754,7 @@ def main(): # Get URLs from various sources urls = None - + if args.url_file: try: with open(args.url_file, 'r', encoding='utf-8') as f: @@ -742,14 +763,14 @@ def main(): except Exception as e: print(f"Error loading URLs from file: {e}") return - + elif args.urls: urls = args.urls - + elif args.tsv_path or args.download_dataset: # Handle FRAMES dataset tsv_path = args.tsv_path - + if tsv_path is None and args.download_dataset: print("=== DOWNLOADING FRAMES DATASET ===") tsv_path = download_frames_dataset(args.data_dir) @@ -759,10 +780,10 @@ def main(): elif tsv_path is None: print("❌ No TSV path provided. Use --tsv-path or --download-dataset") return - + urls = extract_urls_from_frames_dataset(tsv_path, args.max_files) print(f"Extracted {len(urls)} URLs from FRAMES dataset: {tsv_path}") - + else: # Default: download dataset print("=== DOWNLOADING FRAMES DATASET ===") @@ -770,10 +791,10 @@ def main(): if tsv_path is None: print("❌ Failed to download FRAMES dataset. Exiting.") return - + urls = extract_urls_from_frames_dataset(tsv_path, args.max_files) print(f"Extracted {len(urls)} URLs from FRAMES dataset: {tsv_path}") - + if not urls: print("No URLs found to download") return @@ -784,21 +805,22 @@ def main(): user = getpass.getuser() xdg_runtime_dir = f"/tmp/runtime-{user}" os.environ["XDG_RUNTIME_DIR"] = xdg_runtime_dir - + downloader = PDFDownloader(args.output_dir, args.processes) - + elif args.format == 'html': if requests is None: - print("❌ HTML format requires 'requests' package. Install with: pip install requests") + print( + "❌ HTML format requires 'requests' package. Install with: pip install requests") return - + downloader = HTMLDownloader( output_dir=args.output_dir, processes=args.processes, delay=args.delay, timeout=args.timeout ) - + else: print(f"❌ Unsupported format: {args.format}") return diff --git a/e2e-rag/evaluate.py b/e2e-rag/evaluate.py index 9c94ce8890..b963b98124 100644 --- a/e2e-rag/evaluate.py +++ b/e2e-rag/evaluate.py @@ -51,16 +51,18 @@ def load_results(path: Path): # Load pandas DataFrame checkpoint with open(path, 'rb') as f: df = pickle.load(f) - + # Convert DataFrame to dict: query -> llm_answer # Only include successfully completed queries successful = df[df['success'] == True] - return {row['query']: row['llm_answer'] for _, row in successful.iterrows()} + return {row['query']: row['llm_answer'] + for _, row in successful.iterrows()} else: # Legacy JSON format data = json.loads(path.read_text(encoding="utf-8")) results = data.get("results", []) - return {entry.get("prompt"): entry.get("llm_answer", "") for entry in results if entry.get("prompt")} + return {entry.get("prompt"): entry.get("llm_answer", "") + for entry in results if entry.get("prompt")} def _parse_score_value(value) -> int: @@ -122,7 +124,8 @@ def _extract_json_dict(content: str) -> Optional[dict]: return None -def call_judge(session: requests.Session, service_url: str, model: str, question: str, gold: str, pred: str): +def call_judge(session: requests.Session, service_url: str, + model: str, question: str, gold: str, pred: str): prompt = ( "You judge whether the model answer correctly answers the question based on semantic equivalence to the gold answer.\n\n" "GRADING RULES:\n" @@ -168,12 +171,17 @@ def call_judge(session: requests.Session, service_url: str, model: str, question "X-Title": "RAG-QnA Evaluation" } - response = session.post(service_url, json=payload, headers=headers, timeout=120) + response = session.post( + service_url, + json=payload, + headers=headers, + timeout=120) response.raise_for_status() data = response.json() # Defensive: handle missing or malformed 'choices' in response choices = data.get("choices") - if not choices or not isinstance(choices, list) or not choices[0] or "message" not in choices[0] or "content" not in choices[0]["message"]: + if not choices or not isinstance( + choices, list) or not choices[0] or "message" not in choices[0] or "content" not in choices[0]["message"]: print("[ERROR] Judge response missing 'choices' or 'content':", data) # Return score 0, explanation with raw response, and raw data return 0, f"Malformed judge response: {data}", str(data) @@ -207,11 +215,13 @@ def call_judge(session: requests.Session, service_url: str, model: str, question def _judge_row(idx, prompt, gold, pred, service_url, model): """Call judge for a single row, returning (idx, prompt, gold, pred, score, explanation, raw).""" session = requests.Session() - score, explanation, raw = call_judge(session, service_url, model, prompt, gold, pred) + score, explanation, raw = call_judge( + session, service_url, model, prompt, gold, pred) return idx, prompt, gold, pred, score, explanation, raw -def evaluate(results_path: Path, dataset_path: Path, service_url: str, model: str, batch_size: int = 16): +def evaluate(results_path: Path, dataset_path: Path, + service_url: str, model: str, batch_size: int = 16): # Check for OpenRouter API key if using OpenRouter if "openrouter.ai" in service_url and not OPENROUTER_API_KEY: print("ERROR: OPENROUTER_API_KEY environment variable not set") @@ -227,13 +237,15 @@ def evaluate(results_path: Path, dataset_path: Path, service_url: str, model: st print(f"CHECKPOINT STATISTICS") print("=" * 80) print(f"Total queries in checkpoint: {len(checkpoint_df)}") - print(f"Successful queries: {(checkpoint_df['success'] == True).sum()}") + print( + f"Successful queries: {(checkpoint_df['success'] == True).sum()}") print(f"Failed queries: {(checkpoint_df['success'] == False).sum()}") if 'num_docs' in checkpoint_df.columns: total_docs = checkpoint_df['num_docs'].sum() total_missing = checkpoint_df['num_missing_docs'].sum() print(f"Total documents referenced: {total_docs}") - print(f"Missing documents: {total_missing} ({100*total_missing/total_docs:.2f}%)") + print( + f"Missing documents: {total_missing} ({100*total_missing/total_docs:.2f}%)") print("=" * 80) print() @@ -256,7 +268,8 @@ def evaluate(results_path: Path, dataset_path: Path, service_url: str, model: st total = len(items) unknown = sum(1 for _, _, _, pred in items if pred.lower() == "unknown") - # Submit all judge calls in parallel (batch_size workers), print as they complete + # Submit all judge calls in parallel (batch_size workers), print as they + # complete score_sum = 0 with ThreadPoolExecutor(max_workers=batch_size) as executor: futures = { @@ -284,15 +297,38 @@ def evaluate(results_path: Path, dataset_path: Path, service_url: str, model: st def parse_args(): - parser = argparse.ArgumentParser(description="Evaluate single-shot results using an LLM judge.") - parser.add_argument("results", type=Path, help="Path to results (result_single_shot.json or oracle_checkpoint.pkl)") - parser.add_argument("--dataset", type=Path, default=Path("data/frames_dataset.tsv"), help="Evaluation dataset TSV") - parser.add_argument("--judge-url", default=DEFAULT_JUDGE_URL, help="Judge service endpoint") - parser.add_argument("--judge-model", default=DEFAULT_JUDGE_MODEL, help="Judge model identifier") - parser.add_argument("--batch-size", type=int, default=16, help="Number of concurrent judge requests (default: 16)") + parser = argparse.ArgumentParser( + description="Evaluate single-shot results using an LLM judge.") + parser.add_argument( + "results", + type=Path, + help="Path to results (result_single_shot.json or oracle_checkpoint.pkl)") + parser.add_argument( + "--dataset", + type=Path, + default=Path("data/frames_dataset.tsv"), + help="Evaluation dataset TSV") + parser.add_argument( + "--judge-url", + default=DEFAULT_JUDGE_URL, + help="Judge service endpoint") + parser.add_argument( + "--judge-model", + default=DEFAULT_JUDGE_MODEL, + help="Judge model identifier") + parser.add_argument( + "--batch-size", + type=int, + default=16, + help="Number of concurrent judge requests (default: 16)") return parser.parse_args() if __name__ == "__main__": args = parse_args() - evaluate(args.results, args.dataset, args.judge_url, args.judge_model, args.batch_size) + evaluate( + args.results, + args.dataset, + args.judge_url, + args.judge_model, + args.batch_size) diff --git a/e2e-rag/evaluation.py b/e2e-rag/evaluation.py index a6c59a889a..8ed81824fc 100644 --- a/e2e-rag/evaluation.py +++ b/e2e-rag/evaluation.py @@ -32,82 +32,85 @@ from utils import filter_dataset_by_difficulty -def calculate_retrieval_metrics(expected_urls: List[str], retrieved_urls: List[str], k_values: List[int] = [1, 3, 5, 10]) -> Dict[str, float]: +def calculate_retrieval_metrics(expected_urls: List[str], retrieved_urls: List[str], k_values: List[int] = [ + 1, 3, 5, 10]) -> Dict[str, float]: """ Calculate comprehensive retrieval metrics. - + Args: expected_urls: List of expected/ground truth URLs retrieved_urls: List of retrieved URLs in ranking order k_values: List of k values for Precision@k, Recall@k, F1@k - + Returns: Dictionary containing all calculated metrics """ expected_set = set(url for url in expected_urls if url and url.strip()) - + # Handle edge cases if not expected_set: return {f'precision@{k}': 1.0 if len(retrieved_urls) == 0 else 0.0 for k in k_values} | \ {f'recall@{k}': 1.0 for k in k_values} | \ {f'f1@{k}': 1.0 if len(retrieved_urls) == 0 else 0.0 for k in k_values} | \ {'average_precision': 1.0 if len(retrieved_urls) == 0 else 0.0} - + metrics = {} - - # Calculate metrics for different k values, including @N (actual retrieved count) + + # Calculate metrics for different k values, including @N (actual retrieved + # count) num_retrieved = len(retrieved_urls) num_expected = len(expected_set) k_values_with_n = k_values + [num_retrieved] # Add N to k_values - + for k in k_values_with_n: # Determine the label (use 'N' for the actual retrieved count) k_label = 'N' if k == num_retrieved else str(k) - + # Get top k documents top_k = retrieved_urls[:k] top_k_set = set(top_k) relevant_retrieved = len(expected_set.intersection(top_k_set)) - + # Precision@k: fraction of retrieved documents that are relevant precision_k = relevant_retrieved / k if k > 0 else 0.0 metrics[f'precision@{k_label}'] = precision_k - + # Recall@k: fraction of relevant documents that are retrieved recall_k = relevant_retrieved / num_expected if num_expected > 0 else 0.0 metrics[f'recall@{k_label}'] = recall_k - + # F1@k: harmonic mean of precision and recall if precision_k + recall_k > 0: f1_k = 2 * (precision_k * recall_k) / (precision_k + recall_k) else: f1_k = 0.0 metrics[f'f1@{k_label}'] = f1_k - + # Mean Average Precision (MAP) - considers ranking order ap_sum = 0.0 relevant_found = 0 - + for i, url in enumerate(retrieved_urls): if url in expected_set: relevant_found += 1 precision_at_i = relevant_found / (i + 1) ap_sum += precision_at_i - - average_precision = ap_sum / len(expected_set) if len(expected_set) > 0 else 0.0 + + average_precision = ap_sum / \ + len(expected_set) if len(expected_set) > 0 else 0.0 metrics['average_precision'] = average_precision - + return metrics -def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], - top_k_retriever: int = 50, top_k_reranking: int = 10, - verbose: bool = True, no_rerank: bool = False, - retrieval_strategy: str = "fixed_k", print_results: bool = False, - return_results: bool = False, **strategy_params) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[Any]]]: +def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], + top_k_retriever: int = 50, top_k_reranking: int = 10, + verbose: bool = True, no_rerank: bool = False, + retrieval_strategy: str = "fixed_k", print_results: bool = False, + return_results: bool = False, **strategy_params) -> Union[Dict[str, Any], Tuple[Dict[str, Any], List[Any]]]: """ Evaluate a single retrieval query and return comprehensive retrieval metrics. - + Args: rag_db: RAG database instance query: Query string @@ -118,12 +121,12 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], no_rerank: Skip reranking step for fair comparison between retrieval methods retrieval_strategy: Strategy for retrieval ("fixed_k", "top_p", "relative") **strategy_params: Parameters for adaptive retrieval strategies - + Returns: Dictionary containing all metrics. When return_results=True, returns a tuple of (metrics_dict, retrieved_results). """ import time - + # Step 1: Time the initial retrieval retrieval_start = time.perf_counter() if retrieval_strategy == "fixed_k": @@ -131,23 +134,25 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], else: from retrieve.filter import filter max_results = strategy_params.pop("max_results", 20) - results = filter(rag_db, query, method=retrieval_strategy, - max_results=max_results, **strategy_params) + results = filter(rag_db, query, method=retrieval_strategy, + max_results=max_results, **strategy_params) retrieval_time = time.perf_counter() - retrieval_start - - # Step 2: Apply reranking if enabled and reranker is available + + # Step 2: Apply reranking if enabled and reranker is available reranking_time = 0.0 - if not no_rerank and hasattr(rag_db, '_reranker_model') and rag_db._reranker_model is not None: + if not no_rerank and hasattr( + rag_db, '_reranker_model') and rag_db._reranker_model is not None: # Safety check: If no results retrieved, skip reranking if not results: if verbose: - print(f"Warning: No documents retrieved for query: {query[:50]}") + print( + f"Warning: No documents retrieved for query: {query[:50]}") else: reranking_start = time.perf_counter() # Extract text content for reranking (rerank expects strings) passages = [result.page_content for result in results] scored_passages = rag_db.rerank(query, passages) - + # Reconstruct document objects with reranked order # scored_passages is [(text, score), ...] ordered by score reranked_results = [] @@ -157,39 +162,45 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], if doc.page_content == text: reranked_results.append(doc) break - + # Apply top_k_reranking limit AFTER reranking # For adaptive strategies (top_p, relative, etc.), respect the number of documents # selected by the strategy, only limit for fixed_k if retrieval_strategy == "fixed_k": results = reranked_results[:top_k_reranking] else: - # For adaptive strategies, keep all documents selected by the strategy + # For adaptive strategies, keep all documents selected by the + # strategy results = reranked_results reranking_time = time.perf_counter() - reranking_start - + # Extract URLs from results in order (maintaining ranking) retrieved_urls = [] for result in results: if 'original_url' in result.metadata and result.metadata['original_url']: retrieved_urls.append(result.metadata['original_url']) - # Deduplicate URLs preserving first appearance order (for accurate MAP calculation) - deduplicated_urls = list(dict.fromkeys(retrieved_urls)) # Preserves order, removes duplicates - + # Deduplicate URLs preserving first appearance order (for accurate MAP + # calculation) + # Preserves order, removes duplicates + deduplicated_urls = list(dict.fromkeys(retrieved_urls)) + # Calculate comprehensive metrics using deduplicated URLs (accurate MAP) expected_set = set(url for url in expected_urls if url and url.strip()) - metrics = calculate_retrieval_metrics(list(expected_set), deduplicated_urls) - + metrics = calculate_retrieval_metrics( + list(expected_set), deduplicated_urls) + # Track both passages and unique documents num_passages = len(results) num_unique_docs = len(deduplicated_urls) - + if verbose: print(f"Query: {query:50}") matches = len(expected_set.intersection(set(deduplicated_urls))) - print(f"Expected ({len(expected_set)}): {sorted(list(expected_set)[:3])}{'...' if len(expected_set) > 3 else ''}") - print(f"Retrieved ({num_passages} passages, {num_unique_docs} unique docs): {deduplicated_urls[:3]}{'...' if num_unique_docs > 3 else ''}") + print( + f"Expected ({len(expected_set)}): {sorted(list(expected_set)[:3])}{'...' if len(expected_set) > 3 else ''}") + print( + f"Retrieved ({num_passages} passages, {num_unique_docs} unique docs): {deduplicated_urls[:3]}{'...' if num_unique_docs > 3 else ''}") print(f"Matches: {matches}") metric_categories = [ @@ -208,11 +219,12 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], print(f"MAP: {metrics['average_precision']:.3f}") print("-" * 80) - + # Print detailed results for single query mode if print_results: - print(f"\n{retrieval_strategy.upper()} lookup took time. {len(results)} results found:") - + print( + f"\n{retrieval_strategy.upper()} lookup took time. {len(results)} results found:") + # Display which PDFs the passages are from for i, result in enumerate(results, 1): print(f"{i}. {result.metadata}") @@ -232,22 +244,23 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], print("No reranker used (--no-rerank specified)") else: print("No reranker available - showing retrieval results only") - + # Calculate retrieval performance metrics total_time = retrieval_time + reranking_time docs_per_second = len(results) / total_time if total_time > 0 else 0 - + # Add retrieval performance to metrics retrieval_metrics = { 'retrieval_time': retrieval_time, - 'reranking_time': reranking_time, + 'reranking_time': reranking_time, 'total_retrieval_time': total_time, 'retrieved_passages_count': len(results), 'retrieved_docs_count': num_unique_docs, 'docs_per_second': docs_per_second } - - # Print retrieval performance if in benchmark mode and single query mode (not evaluation) + + # Print retrieval performance if in benchmark mode and single query mode + # (not evaluation) if hasattr(rag_db, '_benchmark') and rag_db._benchmark and print_results: print(f"\n🔍 RETRIEVAL PERFORMANCE METRICS") print("=" * 50) @@ -260,7 +273,7 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], print(f"🚀 Retrieval speed: {docs_per_second:.1f} docs/sec") print(f"💾 Time per query: {total_time:.4f}s") print() - + # Return metrics dict with retrieval performance merged_metrics = {**metrics, **retrieval_metrics} if return_results: @@ -268,21 +281,22 @@ def evaluate_retrieval_query(rag_db, query: str, expected_urls: List[str], return merged_metrics -def run_evaluation(rag_db, dataset_path: str, - top_k_retriever: int = 50, top_k_reranking: int = 10, - max_queries: Optional[int] = None, no_rerank: bool = False, - retrieval_strategy: str = "fixed_k", detailed_analysis: bool = False, - difficulty: int = 0, collect_results: bool = False, - result_handler: Optional[Callable[[str, List[Any], Dict[str, Any]], Optional[Any]]] = None, - **strategy_params) -> Union[Dict[str, float], Tuple[Dict[str, float], List[Dict[str, Any]]]]: +def run_evaluation(rag_db, dataset_path: str, + top_k_retriever: int = 50, top_k_reranking: int = 10, + max_queries: Optional[int] = None, no_rerank: bool = False, + retrieval_strategy: str = "fixed_k", detailed_analysis: bool = False, + difficulty: int = 0, collect_results: bool = False, + result_handler: Optional[Callable[[ + str, List[Any], Dict[str, Any]], Optional[Any]]] = None, + **strategy_params) -> Union[Dict[str, float], Tuple[Dict[str, float], List[Dict[str, Any]]]]: """ Run comprehensive evaluation on a dataset with detailed metrics reporting. - + Args: rag_db: RAG database instance dataset_path: Path to the dataset TSV file top_k_retriever: Number of documents to retrieve initially - top_k_reranking: Number of documents after reranking + top_k_reranking: Number of documents after reranking max_queries: Maximum number of queries to evaluate (None = all) no_rerank: Skip reranking step for fair comparison between retrieval methods retrieval_strategy: Strategy for retrieval ("fixed_k", "top_p", "relative") @@ -291,15 +305,15 @@ def run_evaluation(rag_db, dataset_path: str, collect_results: If True, also collect retrieval outputs for each query result_handler: Optional callback invoked per query with (prompt, retrieved_docs, metrics) **strategy_params: Parameters for adaptive retrieval strategies - + Returns: Dictionary of averaged metrics across all queries. When collect_results=True, returns a tuple of (metrics_dict, collected_results). """ df = pd.read_csv(dataset_path, sep='\t') - + # Filter by difficulty if specified df = filter_dataset_by_difficulty(df, difficulty) - + # Limit number of queries if specified if isinstance(max_queries, int) and max_queries > 0: df = df.head(max_queries) @@ -307,7 +321,7 @@ def run_evaluation(rag_db, dataset_path: str, max_queries = len(df) print(f"\nRunning evaluation on {max_queries} queries from dataset") - + # Aggregate metrics collection total_metrics = {} all_query_metrics = [] # Store individual query metrics for detailed analysis @@ -317,19 +331,19 @@ def run_evaluation(rag_db, dataset_path: str, docs_per_sec_list = [] collected_queries = [] if collect_results else None valid_queries = 0 - + for idx, row in df.iterrows(): # Extract expected Wikipedia links expected_urls = [] for col in df.columns: if col.startswith('wikipedia_link_') and pd.notna(row[col]): expected_urls.append(row[col].strip()) - + if expected_urls: # Get comprehensive metrics for this query need_results = collect_results or (result_handler is not None) metrics_output = evaluate_retrieval_query( - rag_db, row['Prompt'], expected_urls, + rag_db, row['Prompt'], expected_urls, top_k_retriever, top_k_reranking, verbose=True, no_rerank=no_rerank, retrieval_strategy=retrieval_strategy, return_results=need_results, **strategy_params @@ -346,7 +360,8 @@ def run_evaluation(rag_db, dataset_path: str, for doc in retrieved_docs: url = None if hasattr(doc, 'metadata'): - url = doc.metadata.get('original_url') or doc.metadata.get('source') + url = doc.metadata.get( + 'original_url') or doc.metadata.get('source') content = doc.page_content elif isinstance(doc, dict): url = doc.get('url') @@ -369,59 +384,73 @@ def run_evaluation(rag_db, dataset_path: str, if result_handler: result_handler(row['Prompt'], retrieved_docs, metrics) - + # Store metrics for detailed analysis if requested if detailed_analysis: all_query_metrics.append(metrics) - + # Collect retrieval performance metrics for statistics if 'retrieval_time' in metrics: retrieval_times.append(metrics['retrieval_time']) reranking_times.append(metrics['reranking_time']) total_times.append(metrics['total_retrieval_time']) docs_per_sec_list.append(metrics['docs_per_second']) - + # Accumulate metrics for metric_name, value in metrics.items(): if metric_name not in total_metrics: total_metrics[metric_name] = 0.0 total_metrics[metric_name] += value - + valid_queries += 1 - + if valid_queries > 0: # Calculate average metrics - avg_metrics = {name: total / valid_queries for name, total in total_metrics.items()} - + avg_metrics = { + name: total / + valid_queries for name, + total in total_metrics.items()} + # Display results results_title = "OVERALL EVALUATION RESULTS" if detailed_analysis else "EVALUATION RESULTS" - print(f"\n" + "="*60) + print(f"\n" + "=" * 60) print(f"{results_title} ({valid_queries} queries)") - print(f"="*60) + print(f"=" * 60) print(f"PRECISION METRICS:") - print(f" Precision@N: {avg_metrics.get('precision@N', 0.0):.3f}") + print( + f" Precision@N: {avg_metrics.get('precision@N', 0.0):.3f}") if 'precision@1' in avg_metrics: - print(f" Precision@1: {avg_metrics['precision@1']:.3f}") + print( + f" Precision@1: {avg_metrics['precision@1']:.3f}") if 'precision@3' in avg_metrics: - print(f" Precision@3: {avg_metrics['precision@3']:.3f}") + print( + f" Precision@3: {avg_metrics['precision@3']:.3f}") if 'precision@5' in avg_metrics: - print(f" Precision@5: {avg_metrics['precision@5']:.3f}") + print( + f" Precision@5: {avg_metrics['precision@5']:.3f}") if 'precision@10' in avg_metrics: - print(f" Precision@10: {avg_metrics['precision@10']:.3f}") + print( + f" Precision@10: {avg_metrics['precision@10']:.3f}") print(f"") print(f"RECALL METRICS:") - print(f" Recall@N: {avg_metrics.get('recall@N', 0.0):.3f}") + print( + f" Recall@N: {avg_metrics.get('recall@N', 0.0):.3f}") if 'recall@1' in avg_metrics: - print(f" Recall@1: {avg_metrics['recall@1']:.3f}") + print( + f" Recall@1: {avg_metrics['recall@1']:.3f}") if 'recall@3' in avg_metrics: - print(f" Recall@3: {avg_metrics['recall@3']:.3f}") + print( + f" Recall@3: {avg_metrics['recall@3']:.3f}") if 'recall@5' in avg_metrics: - print(f" Recall@5: {avg_metrics['recall@5']:.3f}") + print( + f" Recall@5: {avg_metrics['recall@5']:.3f}") if 'recall@10' in avg_metrics: - print(f" Recall@10: {avg_metrics['recall@10']:.3f}") + print( + f" Recall@10: {avg_metrics['recall@10']:.3f}") print(f"") print(f"F1 METRICS:") - print(f" F1@N: {avg_metrics.get('f1@N', 0.0):.3f}") + print( + f" F1@N: {avg_metrics.get('f1@N', 0.0):.3f}") if 'f1@1' in avg_metrics: print(f" F1@1: {avg_metrics['f1@1']:.3f}") if 'f1@3' in avg_metrics: @@ -432,45 +461,61 @@ def run_evaluation(rag_db, dataset_path: str, print(f" F1@10: {avg_metrics['f1@10']:.3f}") print(f"") print(f"RANKING METRICS:") - print(f" Mean Average Precision: {avg_metrics['average_precision']:.3f}") + print( + f" Mean Average Precision: {avg_metrics['average_precision']:.3f}") print(f"") print(f"RETRIEVAL STATISTICS:") - print(f" Avg Passages Retrieved: {avg_metrics.get('retrieved_passages_count', 0.0):.1f}") - print(f" Avg Unique Docs (N): {avg_metrics.get('retrieved_docs_count', 0.0):.1f}") - + print( + f" Avg Passages Retrieved: {avg_metrics.get('retrieved_passages_count', 0.0):.1f}") + print( + f" Avg Unique Docs (N): {avg_metrics.get('retrieved_docs_count', 0.0):.1f}") + # Add retrieval performance statistics if we have retrieval data - if retrieval_times and hasattr(rag_db, '_benchmark') and rag_db._benchmark: + if retrieval_times and hasattr( + rag_db, '_benchmark') and rag_db._benchmark: import numpy as np - + print(f"") print(f"🔍 RETRIEVAL PERFORMANCE STATISTICS:") print(f" Retrieval Time (ms):") - print(f" Average: {np.mean(retrieval_times)*1000:.2f}ms") - print(f" P50 (Median): {np.percentile(retrieval_times, 50)*1000:.2f}ms") - print(f" P99: {np.percentile(retrieval_times, 99)*1000:.2f}ms") - + print( + f" Average: {np.mean(retrieval_times)*1000:.2f}ms") + print( + f" P50 (Median): {np.percentile(retrieval_times, 50)*1000:.2f}ms") + print( + f" P99: {np.percentile(retrieval_times, 99)*1000:.2f}ms") + if any(t > 0 for t in reranking_times): print(f" Reranking Time (ms):") - print(f" Average: {np.mean(reranking_times)*1000:.2f}ms") - print(f" P50 (Median): {np.percentile(reranking_times, 50)*1000:.2f}ms") - print(f" P99: {np.percentile(reranking_times, 99)*1000:.2f}ms") - + print( + f" Average: {np.mean(reranking_times)*1000:.2f}ms") + print( + f" P50 (Median): {np.percentile(reranking_times, 50)*1000:.2f}ms") + print( + f" P99: {np.percentile(reranking_times, 99)*1000:.2f}ms") + print(f" Total Query Time (ms):") - print(f" Average: {np.mean(total_times)*1000:.2f}ms") - print(f" P50 (Median): {np.percentile(total_times, 50)*1000:.2f}ms") - print(f" P99: {np.percentile(total_times, 99)*1000:.2f}ms") - + print( + f" Average: {np.mean(total_times)*1000:.2f}ms") + print( + f" P50 (Median): {np.percentile(total_times, 50)*1000:.2f}ms") + print( + f" P99: {np.percentile(total_times, 99)*1000:.2f}ms") + print(f" Retrieval Throughput (docs/sec):") - print(f" Average: {np.mean(docs_per_sec_list):.1f} docs/sec") - print(f" P50 (Median): {np.percentile(docs_per_sec_list, 50):.1f} docs/sec") - print(f" P99: {np.percentile(docs_per_sec_list, 99):.1f} docs/sec") - - print(f"="*60) - + print( + f" Average: {np.mean(docs_per_sec_list):.1f} docs/sec") + print( + f" P50 (Median): {np.percentile(docs_per_sec_list, 50):.1f} docs/sec") + print( + f" P99: {np.percentile(docs_per_sec_list, 99):.1f} docs/sec") + + print(f"=" * 60) + # Print detailed analysis if requested if detailed_analysis: _print_detailed_analysis(df, all_query_metrics, valid_queries) - + if collect_results: return avg_metrics, collected_queries or [] return avg_metrics @@ -481,12 +526,12 @@ def run_evaluation(rag_db, dataset_path: str, return {} -def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, Any]], - valid_queries: int) -> None: +def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, Any]], + valid_queries: int) -> None: """ Print detailed dataset analysis broken down by reasoning types and answer link counts. (Internal helper function for run_evaluation) - + Args: df: DataFrame with dataset (must have 'reasoning_types' column) all_query_metrics: List of metrics dictionaries for each query @@ -494,131 +539,132 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, """ if valid_queries == 0: return - - print("\n" + "="*80) + + print("\n" + "=" * 80) print("DETAILED DATASET ANALYSIS") - print("="*80) - + print("=" * 80) + # Prepare data - match metrics with reasoning types and link counts analysis_data = [] for idx, metrics in enumerate(all_query_metrics): if idx < len(df): row = df.iloc[idx] reasoning_types = row.get('reasoning_types', 'Unknown') - + # Count Wikipedia links - num_links = sum(1 for col in df.columns - if col.startswith('wikipedia_link_') and pd.notna(row[col])) - + num_links = sum(1 for col in df.columns + if col.startswith('wikipedia_link_') and pd.notna(row[col])) + analysis_data.append({ 'reasoning_types': reasoning_types, 'num_links': num_links, 'metrics': metrics }) - + # === ANALYSIS 1: By Reasoning Classification === - print("\n" + "-"*80) + print("\n" + "-" * 80) print("ANALYSIS BY REASONING CLASSIFICATION") - print("-"*80) - + print("-" * 80) + # Group by reasoning types reasoning_groups = defaultdict(list) for data in analysis_data: reasoning_groups[data['reasoning_types']].append(data['metrics']) - + # Calculate averages for each reasoning type reasoning_results = [] for reasoning_type, metrics_list in reasoning_groups.items(): if not metrics_list: continue - + avg_metrics = {} for key in ['precision@N', 'recall@N', 'f1@N', 'average_precision']: values = [m.get(key, 0.0) for m in metrics_list] avg_metrics[key] = sum(values) / len(values) - + reasoning_results.append({ 'type': reasoning_type, 'count': len(metrics_list), **avg_metrics }) - + # Sort by count (most common first) reasoning_results.sort(key=lambda x: x['count'], reverse=True) - + # Print top reasoning types print(f"\nTop reasoning type combinations:") print(f"{'Reasoning Type':<50} {'Count':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") - print("-"*80) - + print("-" * 80) + for i, result in enumerate(reasoning_results): - rt = result['type'][:48] if len(result['type']) > 48 else result['type'] + rt = result['type'][:48] if len( + result['type']) > 48 else result['type'] print(f"{rt:<50} {result['count']:6d} " f"{result['precision@N']:6.3f} {result['recall@N']:6.3f} " f"{result['f1@N']:6.3f} {result['average_precision']:6.3f}") - + # === ANALYSIS 2: By Individual Reasoning Tags === - print(f"\n" + "-"*80) + print(f"\n" + "-" * 80) print("ANALYSIS BY INDIVIDUAL REASONING TAGS") - print("-"*80) - + print("-" * 80) + # Parse reasoning tags (split by |) tag_groups = defaultdict(list) for data in analysis_data: tags = [tag.strip() for tag in data['reasoning_types'].split('|')] for tag in tags: tag_groups[tag].append(data['metrics']) - + tag_results = [] for tag, metrics_list in tag_groups.items(): if not metrics_list: continue - + avg_metrics = {} for key in ['precision@N', 'recall@N', 'f1@N', 'average_precision']: values = [m.get(key, 0.0) for m in metrics_list] avg_metrics[key] = sum(values) / len(values) - + tag_results.append({ 'tag': tag, 'count': len(metrics_list), 'percentage': len(metrics_list) / valid_queries * 100, **avg_metrics }) - + # Sort by count tag_results.sort(key=lambda x: x['count'], reverse=True) - + print(f"\nPerformance by reasoning tag:") print(f"{'Tag':<30} {'Count':>6} {'%':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") - print("-"*80) - + print("-" * 80) + for result in tag_results: tag = result['tag'][:28] if len(result['tag']) > 28 else result['tag'] print(f"{tag:<30} {result['count']:6d} {result['percentage']:5.1f}% " f"{result['precision@N']:6.3f} {result['recall@N']:6.3f} " f"{result['f1@N']:6.3f} {result['average_precision']:6.3f}") - + # === ANALYSIS 3: By Number of Answer Links === - print(f"\n" + "-"*80) + print(f"\n" + "-" * 80) print("ANALYSIS BY NUMBER OF ANSWER LINKS (Multi-hop Analysis)") - print("-"*80) - + print("-" * 80) + # Group by number of links link_groups = defaultdict(list) for data in analysis_data: link_groups[data['num_links']].append(data['metrics']) - + link_results = [] for num_links, metrics_list in link_groups.items(): if not metrics_list: continue - + avg_metrics = {} for key in ['precision@N', 'recall@N', 'f1@N', 'average_precision']: values = [m.get(key, 0.0) for m in metrics_list] avg_metrics[key] = sum(values) / len(values) - + # Classify complexity if num_links <= 2: complexity = "Simple" @@ -626,7 +672,7 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, complexity = "Multi-hop" else: complexity = "Complex" - + link_results.append({ 'num_links': num_links, 'complexity': complexity, @@ -634,25 +680,25 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, 'percentage': len(metrics_list) / valid_queries * 100, **avg_metrics }) - + # Sort by number of links link_results.sort(key=lambda x: x['num_links']) - + print(f"\nPerformance by number of Wikipedia links (reasoning hops):") print(f"{'Links':>5} {'Complexity':<12} {'Count':>6} {'%':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") - print("-"*80) - + print("-" * 80) + for result in link_results: print(f"{result['num_links']:5d} {result['complexity']:<12} " f"{result['count']:6d} {result['percentage']:5.1f}% " f"{result['precision@N']:6.3f} {result['recall@N']:6.3f} " f"{result['f1@N']:6.3f} {result['average_precision']:6.3f}") - + # Summary by complexity category - print(f"\n" + "-"*80) + print(f"\n" + "-" * 80) print("SUMMARY BY COMPLEXITY LEVEL") - print("-"*80) - + print("-" * 80) + complexity_groups = defaultdict(list) for result in link_results: for _ in range(result['count']): @@ -663,49 +709,51 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, 'f1@N': result['f1@N'], 'average_precision': result['average_precision'] }) - + # Calculate totals complexity_summary = [] for complexity in ["Simple", "Multi-hop", "Complex"]: if complexity not in complexity_groups: continue - + metrics_list = complexity_groups[complexity] count = len(metrics_list) - + # Recalculate from link_results - matching_results = [r for r in link_results if r['complexity'] == complexity] + matching_results = [ + r for r in link_results if r['complexity'] == complexity] total_count = sum(r['count'] for r in matching_results) - + # Weighted average weighted_metrics = {} for key in ['precision@N', 'recall@N', 'f1@N', 'average_precision']: weighted_sum = sum(r[key] * r['count'] for r in matching_results) - weighted_metrics[key] = weighted_sum / total_count if total_count > 0 else 0.0 - + weighted_metrics[key] = weighted_sum / \ + total_count if total_count > 0 else 0.0 + complexity_summary.append({ 'complexity': complexity, 'count': total_count, 'percentage': total_count / valid_queries * 100, **weighted_metrics }) - + print(f"\n{'Complexity':<12} {'Count':>6} {'%':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") - print("-"*80) - + print("-" * 80) + for result in complexity_summary: print(f"{result['complexity']:<12} {result['count']:6d} {result['percentage']:5.1f}% " f"{result['precision@N']:6.3f} {result['recall@N']:6.3f} " f"{result['f1@N']:6.3f} {result['average_precision']:6.3f}") - + # === ANALYSIS 4: Correlation Between Complexity and Reasoning Types === - print(f"\n" + "-"*80) + print(f"\n" + "-" * 80) print("CORRELATION: COMPLEXITY vs REASONING TYPES") - print("-"*80) - + print("-" * 80) + # Build correlation matrix: complexity level x reasoning tags complexity_reasoning_data = defaultdict(lambda: defaultdict(list)) - + for data in analysis_data: # Determine complexity num_links = data['num_links'] @@ -715,44 +763,60 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, complexity = "Multi-hop" else: complexity = "Complex" - + # Extract individual reasoning tags tags = [tag.strip() for tag in data['reasoning_types'].split('|')] for tag in tags: complexity_reasoning_data[complexity][tag].append(data['metrics']) - + # Calculate statistics for each complexity-reasoning combination print(f"\n1. REASONING TAG DISTRIBUTION BY COMPLEXITY:") print(f"{'Reasoning Tag':<30} {'Simple':>10} {'Multi-hop':>10} {'Complex':>10} {'Total':>10}") - print("-"*80) - + print("-" * 80) + # Get all unique tags all_tags = set() for complexity_data in complexity_reasoning_data.values(): all_tags.update(complexity_data.keys()) - + tag_distribution = {} for tag in sorted(all_tags): - simple_count = len(complexity_reasoning_data.get('Simple', {}).get(tag, [])) - multihop_count = len(complexity_reasoning_data.get('Multi-hop', {}).get(tag, [])) - complex_count = len(complexity_reasoning_data.get('Complex', {}).get(tag, [])) + simple_count = len( + complexity_reasoning_data.get( + 'Simple', + {}).get( + tag, + [])) + multihop_count = len( + complexity_reasoning_data.get( + 'Multi-hop', + {}).get( + tag, + [])) + complex_count = len( + complexity_reasoning_data.get( + 'Complex', + {}).get( + tag, + [])) total = simple_count + multihop_count + complex_count - + tag_distribution[tag] = { 'simple': simple_count, 'multihop': multihop_count, 'complex': complex_count, 'total': total } - + tag_display = tag[:28] if len(tag) > 28 else tag - print(f"{tag_display:<30} {simple_count:10d} {multihop_count:10d} {complex_count:10d} {total:10d}") - + print( + f"{tag_display:<30} {simple_count:10d} {multihop_count:10d} {complex_count:10d} {total:10d}") + # Calculate percentage distribution print(f"\n2. REASONING TAG PERCENTAGE BY COMPLEXITY:") print(f"{'Reasoning Tag':<30} {'Simple %':>10} {'Multi %':>10} {'Complex %':>10}") - print("-"*80) - + print("-" * 80) + for tag in sorted(all_tags): dist = tag_distribution[tag] total = dist['total'] @@ -760,56 +824,67 @@ def _print_detailed_analysis(df: pd.DataFrame, all_query_metrics: List[Dict[str, simple_pct = (dist['simple'] / total) * 100 multihop_pct = (dist['multihop'] / total) * 100 complex_pct = (dist['complex'] / total) * 100 - + tag_display = tag[:28] if len(tag) > 28 else tag - print(f"{tag_display:<30} {simple_pct:9.1f}% {multihop_pct:9.1f}% {complex_pct:9.1f}%") - + print( + f"{tag_display:<30} {simple_pct:9.1f}% {multihop_pct:9.1f}% {complex_pct:9.1f}%") + # Calculate average number of links per reasoning tag print(f"\n3. AVERAGE COMPLEXITY (# LINKS) BY REASONING TAG:") print(f"{'Reasoning Tag':<30} {'Avg Links':>10} {'Count':>10}") - print("-"*80) - + print("-" * 80) + tag_link_stats = defaultdict(list) for data in analysis_data: tags = [tag.strip() for tag in data['reasoning_types'].split('|')] for tag in tags: tag_link_stats[tag].append(data['num_links']) - + tag_avg_links = [] for tag in sorted(all_tags): links = tag_link_stats[tag] if links: avg_links = sum(links) / len(links) tag_avg_links.append((tag, avg_links, len(links))) - + # Sort by average links (descending) tag_avg_links.sort(key=lambda x: x[1], reverse=True) - + for tag, avg_links, count in tag_avg_links: tag_display = tag[:28] if len(tag) > 28 else tag print(f"{tag_display:<30} {avg_links:10.2f} {count:10d}") - + # Performance by complexity x reasoning tag (for top tags only) print(f"\n4. PERFORMANCE BY COMPLEXITY x TOP REASONING TAGS:") - print("-"*80) - + print("-" * 80) + # Get top 5 most common tags - top_tags = sorted(tag_distribution.items(), key=lambda x: x[1]['total'], reverse=True)[:5] - + top_tags = sorted( + tag_distribution.items(), + key=lambda x: x[1]['total'], + reverse=True)[ + :5] + for tag, _ in top_tags: print(f"\n{tag}:") - print(f"{'Complexity':<12} {'Count':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") - print("-"*70) - + print( + f"{'Complexity':<12} {'Count':>6} {'P@N':>6} {'R@N':>6} {'F1@N':>6} {'MAP':>6}") + print("-" * 70) + for complexity in ["Simple", "Multi-hop", "Complex"]: - metrics_list = complexity_reasoning_data.get(complexity, {}).get(tag, []) + metrics_list = complexity_reasoning_data.get( + complexity, {}).get(tag, []) if metrics_list: count = len(metrics_list) - avg_p = sum(m.get('precision@N', 0.0) for m in metrics_list) / count - avg_r = sum(m.get('recall@N', 0.0) for m in metrics_list) / count + avg_p = sum(m.get('precision@N', 0.0) + for m in metrics_list) / count + avg_r = sum(m.get('recall@N', 0.0) + for m in metrics_list) / count avg_f1 = sum(m.get('f1@N', 0.0) for m in metrics_list) / count - avg_map = sum(m.get('average_precision', 0.0) for m in metrics_list) / count - - print(f"{complexity:<12} {count:6d} {avg_p:6.3f} {avg_r:6.3f} {avg_f1:6.3f} {avg_map:6.3f}") - - print("="*80) \ No newline at end of file + avg_map = sum(m.get('average_precision', 0.0) + for m in metrics_list) / count + + print( + f"{complexity:<12} {count:6d} {avg_p:6.3f} {avg_r:6.3f} {avg_f1:6.3f} {avg_map:6.3f}") + + print("=" * 80) diff --git a/e2e-rag/ingestion_monitor.py b/e2e-rag/ingestion_monitor.py index b3f1e69179..17aa004322 100644 --- a/e2e-rag/ingestion_monitor.py +++ b/e2e-rag/ingestion_monitor.py @@ -20,17 +20,17 @@ Usage: from ingestion_monitor import IngestionMonitor - + monitor = IngestionMonitor() - + # Track document processing with monitor.track_component("html_parsing"): process_html_files(files) - - # Track embedding generation + + # Track embedding generation with monitor.track_component("embedding_generation"): embeddings = generate_embeddings(texts) - + # Get performance report report = monitor.get_performance_report() """ @@ -43,6 +43,7 @@ from dataclasses import dataclass, asdict from pathlib import Path + @dataclass class ComponentMetrics: """Metrics for a single pipeline component.""" @@ -56,6 +57,7 @@ class ComponentMetrics: is_pipeline_input: bool = False # Mark if this is a pipeline input component is_pipeline_output: bool = False # Mark if this is a pipeline output component + @dataclass class IndexingTrendPoint: """Single data point for indexing performance trend.""" @@ -64,7 +66,8 @@ class IndexingTrendPoint: indexing_time: float # Time to add this batch (seconds) throughput_items_per_sec: float cumulative_time: float # Total time so far - + + @dataclass class IngestionReport: """Complete ingestion performance report.""" @@ -75,12 +78,14 @@ class IngestionReport: overall_throughput_mb_per_sec: float components: List[ComponentMetrics] bottleneck_component: str - indexing_trend: List[IndexingTrendPoint] = None # For scaling analysis = "none" + # For scaling analysis = "none" + indexing_trend: List[IndexingTrendPoint] = None bottleneck_component: str + class IngestionMonitor: """Real-time ingestion performance monitoring.""" - + def __init__(self): self.components: Dict[str, ComponentMetrics] = {} self.start_time = None # Will be set when ingestion starts @@ -88,17 +93,17 @@ def __init__(self): self.component_start_time = None self.indexing_trend: List[IndexingTrendPoint] = [] self.cumulative_indexing_time = 0.0 - + def start_ingestion(self): """Mark the start of ingestion. Should be called at the beginning of ingest().""" self.start_time = time.time() - + @contextmanager - def track_component(self, component_name: str, input_size_bytes: int = 0, - items_count: int = 0, text_only: bool = False, - is_pipeline_input: bool = False, is_pipeline_output: bool = False): + def track_component(self, component_name: str, input_size_bytes: int = 0, + items_count: int = 0, text_only: bool = False, + is_pipeline_input: bool = False, is_pipeline_output: bool = False): """Context manager to track performance of a pipeline component. - + Args: component_name: Name of the component being tracked input_size_bytes: Input data size in bytes @@ -108,53 +113,54 @@ def track_component(self, component_name: str, input_size_bytes: int = 0, is_pipeline_output: If True, mark as pipeline output component for aggregation """ start_time = time.time() - + class ComponentContext: def __init__(self): self.input_size_bytes = input_size_bytes self.items_count = items_count self.text_only = text_only - + def set_input_size(self, size_bytes: int): self.input_size_bytes = size_bytes - + def set_item_count(self, count: int): self.items_count = count - + def add_text_bytes(self, text_bytes: int): """Add text-only bytes for passage tracking.""" self.input_size_bytes += text_bytes - + context = ComponentContext() - + try: self.current_component = component_name self.component_start_time = start_time yield context - + finally: end_time = time.time() duration = end_time - start_time - - # Calculate throughput + + # Calculate throughput total_input = context.input_size_bytes - total_output = 0 # Output will be set separately + total_output = 0 # Output will be set separately total_items = context.items_count total_duration = duration - + # Check if component already exists (accumulate metrics) if component_name in self.components: existing = self.components[component_name] # Accumulate metrics total_duration = existing.duration + duration total_input = existing.input_size_bytes + context.input_size_bytes - total_output = existing.output_size_bytes + 0 + total_output = existing.output_size_bytes + 0 total_items = existing.items_processed + context.items_count - + is_pipeline_input = is_pipeline_input or existing.is_pipeline_input is_pipeline_output = is_pipeline_output or existing.is_pipeline_output - throughput_mb = (total_input / (1024 * 1024)) / total_duration if total_duration > 0 else 0 + throughput_mb = (total_input / (1024 * 1024)) / \ + total_duration if total_duration > 0 else 0 throughput_items = total_items / total_duration if total_duration > 0 else 0 self.components[component_name] = ComponentMetrics( @@ -168,18 +174,18 @@ def add_text_bytes(self, text_bytes: int): is_pipeline_input=is_pipeline_input, is_pipeline_output=is_pipeline_output ) - + def set_output_size(self, component_name: str, output_size_bytes: int): """Set the output size for a component after processing.""" if component_name in self.components: self.components[component_name].output_size_bytes = output_size_bytes - + def set_output_size_callback(self, component_name: str, callback_fn): """Set the output size for a component using a callback function. - + This is useful when the output size calculation is complex or requires accessing class-specific data (e.g., BM25 index files). - + Args: component_name: Name of the component callback_fn: Function that returns the output size in bytes @@ -189,40 +195,41 @@ def set_output_size_callback(self, component_name: str, callback_fn): output_size = callback_fn() self.components[component_name].output_size_bytes = output_size except Exception as e: - print(f"Warning: Failed to calculate output size for {component_name}: {e}") - + print( + f"Warning: Failed to calculate output size for {component_name}: {e}") + @contextmanager def track_ingestion(self): """Track overall ingestion performance.""" self.start_time = time.time() # Set start_time for get_performance_report() - + class IngestionContext: def __init__(self): self.item_count = 0 - + def set_item_count(self, count: int): self.item_count = count - + context = IngestionContext() - + try: yield context finally: pass # start_time is checked by get_performance_report() - - def track_incremental_indexing(self, db_size_before: int, batch_size: int, - indexing_time: float): + + def track_incremental_indexing(self, db_size_before: int, batch_size: int, + indexing_time: float): """Track indexing performance for incremental batches to analyze scaling trends. - + Args: db_size_before: Number of items in DB before adding this batch - batch_size: Number of items added in this batch + batch_size: Number of items added in this batch indexing_time: Time taken to index this batch (seconds) """ db_size_after = db_size_before + batch_size throughput = batch_size / indexing_time if indexing_time > 0 else 0 self.cumulative_indexing_time += indexing_time - + trend_point = IndexingTrendPoint( db_size=db_size_after, batch_size=batch_size, @@ -230,36 +237,44 @@ def track_incremental_indexing(self, db_size_before: int, batch_size: int, throughput_items_per_sec=throughput, cumulative_time=self.cumulative_indexing_time ) - + self.indexing_trend.append(trend_point) - + def get_performance_report(self) -> IngestionReport: """Generate comprehensive performance report.""" # Calculate duration from when start_ingestion() was called if self.start_time is None: - raise ValueError("start_ingestion() must be called before getting performance report") - + raise ValueError( + "start_ingestion() must be called before getting performance report") + total_duration = time.time() - self.start_time - + # Aggregate metrics based on pipeline input/output flags # If no flags set, fall back to first component for input - input_components = [c for c in self.components.values() if c.is_pipeline_input] - output_components = [c for c in self.components.values() if c.is_pipeline_output] - + input_components = [ + c for c in self.components.values() if c.is_pipeline_input] + output_components = [ + c for c in self.components.values() if c.is_pipeline_output] + total_input = sum(c.input_size_bytes for c in input_components) total_items = sum(c.items_processed for c in input_components) total_output = sum(c.output_size_bytes for c in output_components) - - overall_throughput = (total_input / (1024 * 1024)) / total_duration if total_duration > 0 else 0 - + + overall_throughput = (total_input / (1024 * 1024)) / \ + total_duration if total_duration > 0 else 0 + # Find bottleneck and calculate efficiency ratio bottleneck_name = "none" - + if self.components: - bottleneck = min(self.components.values(), key=lambda x: x.throughput_mb_per_sec) - fastest = max(self.components.values(), key=lambda x: x.throughput_mb_per_sec) + bottleneck = min( + self.components.values(), + key=lambda x: x.throughput_mb_per_sec) + fastest = max( + self.components.values(), + key=lambda x: x.throughput_mb_per_sec) bottleneck_name = bottleneck.name - + return IngestionReport( total_duration=total_duration, total_input_bytes=total_input, @@ -270,60 +285,72 @@ def get_performance_report(self) -> IngestionReport: bottleneck_component=bottleneck_name, indexing_trend=self.indexing_trend if self.indexing_trend else None ) - + def save_report(self, filename: str = "ingestion_performance.json"): """Save performance report to JSON file.""" report = self.get_performance_report() - + # Convert to serializable format report_dict = asdict(report) - + with open(filename, 'w') as f: json.dump(report_dict, f, indent=2) - + return report_dict - + def print_summary(self): """Print detailed performance summary with individual components.""" report = self.get_performance_report() - + print("🚀 INGESTION PERFORMANCE SUMMARY") print("=" * 60) print(f"📊 Overall Metrics:") print(f" Total duration: {report.total_duration:.2f}s") - print(f" Overall throughput: {report.overall_throughput_mb_per_sec:.2f} MB/s") + print( + f" Overall throughput: {report.overall_throughput_mb_per_sec:.2f} MB/s") print(f" Items processed: {report.total_items:,}") - + # DEBUG: Show detailed breakdown of input data aggregation - input_components = [c for c in report.components if c.is_pipeline_input] + input_components = [ + c for c in report.components if c.is_pipeline_input] print(f"\n🔍 DEBUG: Input Data Breakdown (is_pipeline_input=True):") print(f" {'Component':<30} {'Input Size (MB)':<20} {'Items':<15}") print(f" {'-'*65}") total_input_debug = 0 for comp in input_components: - input_mb = comp.input_size_bytes / (1024*1024) + input_mb = comp.input_size_bytes / (1024 * 1024) total_input_debug += comp.input_size_bytes - print(f" {comp.name:<30} {input_mb:>18.2f} MB {comp.items_processed:>12,}") + print( + f" {comp.name:<30} {input_mb:>18.2f} MB {comp.items_processed:>12,}") print(f" {'-'*65}") - print(f" {'TOTAL AGGREGATED INPUT':<30} {total_input_debug/(1024*1024):>18.2f} MB") - - # Show input data size from report (aggregated from marked input components or first component) - print(f"\n Input data size (from report): {report.total_input_bytes / (1024*1024):.2f} MB") - + print( + f" {'TOTAL AGGREGATED INPUT':<30} {total_input_debug/(1024*1024):>18.2f} MB") + + # Show input data size from report (aggregated from marked input + # components or first component) + print( + f"\n Input data size (from report): {report.total_input_bytes / (1024*1024):.2f} MB") + # Show output size and expansion ratio if output data exists if report.total_output_bytes > 0: - output_size_mb = report.total_output_bytes / (1024*1024) - expansion_ratio = report.total_output_bytes / report.total_input_bytes if report.total_input_bytes > 0 else 0 + output_size_mb = report.total_output_bytes / (1024 * 1024) + expansion_ratio = report.total_output_bytes / \ + report.total_input_bytes if report.total_input_bytes > 0 else 0 print(f" Output data size: {output_size_mb:.2f} MB") print(f" Output/Input ratio: {expansion_ratio:.1f}x") - + print(f" Bottleneck component: {report.bottleneck_component}") - + print(f"\n🔧 Component Performance Details:") - for component in sorted(report.components, key=lambda x: x.duration, reverse=True): - percentage = (component.duration / report.total_duration) * 100 if report.total_duration > 0 else 0 - mb_processed = component.input_size_bytes / (1024*1024) - avg_latency_ms = (component.duration * 1000 / component.items_processed) if component.items_processed > 0 else 0 + for component in sorted( + report.components, key=lambda x: x.duration, reverse=True): + percentage = (component.duration / report.total_duration) * \ + 100 if report.total_duration > 0 else 0 + mb_processed = component.input_size_bytes / (1024 * 1024) + avg_latency_ms = ( + component.duration * + 1000 / + component.items_processed) if component.items_processed > 0 else 0 pipeline_flags = [] if component.is_pipeline_input: pipeline_flags.append("INPUT") @@ -331,33 +358,38 @@ def print_summary(self): pipeline_flags.append("OUTPUT") flag_str = f" [{', '.join(pipeline_flags)}]" if pipeline_flags else "" print(f" 📈 {component.name}{flag_str}:") - print(f" ⏱️ Duration: {component.duration:.3f}s ({percentage:.1f}% of total)") - print(f" 🚀 Throughput: {component.throughput_mb_per_sec:.2f} MB/s") + print( + f" ⏱️ Duration: {component.duration:.3f}s ({percentage:.1f}% of total)") + print( + f" 🚀 Throughput: {component.throughput_mb_per_sec:.2f} MB/s") print(f" 📦 Items: {component.items_processed:,}") print(f" 💾 Data: {mb_processed:.2f} MB") print(f" ⚡ Avg latency: {avg_latency_ms:.2f}ms per item") print() - + # Print indexing trend analysis if available if report.indexing_trend and len(report.indexing_trend) > 1: print("📈 VECTOR DB INDEXING SCALING ANALYSIS") print("=" * 60) print("DB Size → Batch Time (Throughput)") - + for i, point in enumerate(report.indexing_trend): db_size_k = point.db_size // 1000 if point.db_size >= 1000 else point.db_size size_unit = "K" if point.db_size >= 1000 else "" - - print(f" {db_size_k:>4}{size_unit} docs → {point.indexing_time:>6.3f}s ({point.throughput_items_per_sec:>6.1f} docs/sec)") - + + print( + f" {db_size_k:>4}{size_unit} docs → {point.indexing_time:>6.3f}s ({point.throughput_items_per_sec:>6.1f} docs/sec)") + # Calculate scaling trend if len(report.indexing_trend) >= 3: first_point = report.indexing_trend[0] last_point = report.indexing_trend[-1] - - size_ratio = last_point.db_size / first_point.db_size if first_point.db_size > 0 else 0 - time_ratio = last_point.indexing_time / first_point.indexing_time if first_point.indexing_time > 0 else 0 - + + size_ratio = last_point.db_size / \ + first_point.db_size if first_point.db_size > 0 else 0 + time_ratio = last_point.indexing_time / \ + first_point.indexing_time if first_point.indexing_time > 0 else 0 + if size_ratio > 1: scaling_factor = time_ratio / size_ratio if scaling_factor > 1.5: @@ -366,9 +398,10 @@ def print_summary(self): trend_desc = "📊 Linear scaling (time proportional to size)" else: trend_desc = "📉 Sub-linear scaling (indexing gets more efficient)" - + print(f"\n💡 Trend Analysis:") - print(f" Size increased {size_ratio:.1f}x, time increased {time_ratio:.1f}x") + print( + f" Size increased {size_ratio:.1f}x, time increased {time_ratio:.1f}x") print(f" {trend_desc}") print() @@ -376,13 +409,14 @@ def print_summary(self): if __name__ == "__main__": # Example usage monitor = IngestionMonitor() - + # Simulate components - with monitor.track_component("html_parsing", 1024*1024, 100): # 1MB, 100 files + with monitor.track_component("html_parsing", 1024 * 1024, 100): # 1MB, 100 files time.sleep(0.1) - - with monitor.track_component("embedding_generation", 512*1024, 500): # 512KB, 500 chunks + + # 512KB, 500 chunks + with monitor.track_component("embedding_generation", 512 * 1024, 500): time.sleep(0.5) - + monitor.print_summary() monitor.save_report("example_performance.json") diff --git a/e2e-rag/llm_logger.py b/e2e-rag/llm_logger.py index 7207edf1e8..99f1d7e0b3 100644 --- a/e2e-rag/llm_logger.py +++ b/e2e-rag/llm_logger.py @@ -31,7 +31,8 @@ class LLMLogger: """Logger for tracking all LLM calls with full input/output and metrics.""" - def __init__(self, output_file: str = None, experiment_metadata: Dict[str, Any] = None): + def __init__(self, output_file: str = None, + experiment_metadata: Dict[str, Any] = None): self.session_id = str(uuid.uuid4()) self.queries = [] self.output_file = output_file @@ -159,7 +160,8 @@ def _append_query_to_file(self, query_data: Dict): data['queries'].append(query_data) # Update experiment summary - data['experiment_summary'] = self._calculate_experiment_summary(data['queries']) + data['experiment_summary'] = self._calculate_experiment_summary( + data['queries']) # Write back with open(self.output_file, 'w', encoding='utf-8') as f: @@ -187,7 +189,8 @@ def _calculate_experiment_summary(self, queries: List[Dict]) -> Dict: } # Add retrieval/answer metrics if available - queries_with_retrieval = [q for q in queries if "retrieval_results" in q] + queries_with_retrieval = [ + q for q in queries if "retrieval_results" in q] if queries_with_retrieval: experiment_summary["retrieval_metrics"] = { "average_precision": round(sum(q["retrieval_results"].get("precision", 0) for q in queries_with_retrieval) / len(queries_with_retrieval), 4), @@ -197,7 +200,9 @@ def _calculate_experiment_summary(self, queries: List[Dict]) -> Dict: queries_with_answers = [q for q in queries if "answer_results" in q] if queries_with_answers: - correct_count = sum(1 for q in queries_with_answers if q["answer_results"].get("judge_score", 0) >= 4) + correct_count = sum( + 1 for q in queries_with_answers if q["answer_results"].get( + "judge_score", 0) >= 4) experiment_summary["answer_metrics"] = { "average_judge_score": round(sum(q["answer_results"].get("judge_score", 0) for q in queries_with_answers) / len(queries_with_answers), 2), "queries_correct": correct_count, @@ -207,14 +212,17 @@ def _calculate_experiment_summary(self, queries: List[Dict]) -> Dict: return experiment_summary - def end_query(self, retrieval_results: Dict = None, answer_results: Dict = None, wall_time_s: float = None): + def end_query(self, retrieval_results: Dict = None, + answer_results: Dict = None, wall_time_s: float = None): """Finish logging current query, compute summary, and write to file""" if self.current_query: - self.current_query["timestamp_end"] = datetime.utcnow().isoformat() + "Z" + self.current_query["timestamp_end"] = datetime.utcnow( + ).isoformat() + "Z" # Calculate summary llm_calls = self.current_query["llm_calls"] - hop_counts = [c["hop_count"] for c in llm_calls if c["hop_count"] is not None] + hop_counts = [c["hop_count"] + for c in llm_calls if c["hop_count"] is not None] summary = { "total_llm_calls": len(llm_calls), @@ -244,7 +252,8 @@ def end_query(self, retrieval_results: Dict = None, answer_results: Dict = None, self.current_query = None - def save(self, output_file: str = None, experiment_metadata: Dict[str, Any] = None): + def save(self, output_file: str = None, + experiment_metadata: Dict[str, Any] = None): """Save all logs to JSON file (legacy method for backward compatibility). Note: If logger was initialized with output_file, logs are already written @@ -280,21 +289,28 @@ def save(self, output_file: str = None, experiment_metadata: Dict[str, Any] = No print(f"LLM logs saved to: {target_file}") print(f"Total queries: {len(self.queries)}") if experiment_summary: - print(f"Total LLM calls: {experiment_summary.get('total_llm_calls', 0)}") + print( + f"Total LLM calls: {experiment_summary.get('total_llm_calls', 0)}") print(f"Total tokens: {experiment_summary.get('total_tokens', 0):,} (input: {experiment_summary.get('total_input_tokens', 0):,}, output: {experiment_summary.get('total_output_tokens', 0):,})") # Per-query latency distribution (wall time: query to answer) per_query_latencies = sorted( - (q["summary"].get("total_wall_time_ms") or q["summary"].get("total_latency_ms", 0)) / 1000 + (q["summary"].get("total_wall_time_ms") + or q["summary"].get("total_latency_ms", 0)) / 1000 for q in self.queries if "summary" in q ) n = len(per_query_latencies) if n > 0: mean_lat = sum(per_query_latencies) / n - median_lat = per_query_latencies[n // 2] if n % 2 == 1 else (per_query_latencies[n // 2 - 1] + per_query_latencies[n // 2]) / 2 - p90_lat = per_query_latencies[int(n * 0.90)] if n >= 10 else per_query_latencies[-1] - p99_lat = per_query_latencies[int(n * 0.99)] if n >= 100 else per_query_latencies[-1] + median_lat = per_query_latencies[n // 2] if n % 2 == 1 else ( + per_query_latencies[n // 2 - 1] + per_query_latencies[n // 2]) / 2 + p90_lat = per_query_latencies[int( + n * 0.90)] if n >= 10 else per_query_latencies[-1] + p99_lat = per_query_latencies[int( + n * 0.99)] if n >= 100 else per_query_latencies[-1] total_latency_s = sum(per_query_latencies) - print(f"Per-query latency (query-to-answer): mean={mean_lat:.2f}s median={median_lat:.2f}s p90={p90_lat:.2f}s p99={p99_lat:.2f}s") - print(f"Throughput: {n / total_latency_s:.4f} queries/sec ({total_latency_s / n:.2f}s per query)") + print( + f"Per-query latency (query-to-answer): mean={mean_lat:.2f}s median={median_lat:.2f}s p90={p90_lat:.2f}s p99={p99_lat:.2f}s") + print( + f"Throughput: {n / total_latency_s:.4f} queries/sec ({total_latency_s / n:.2f}s per query)") print(f"{'='*80}\n") diff --git a/e2e-rag/measure_indexing_with_chunking.py b/e2e-rag/measure_indexing_with_chunking.py index b8069c24f0..1ef82c2a1a 100644 --- a/e2e-rag/measure_indexing_with_chunking.py +++ b/e2e-rag/measure_indexing_with_chunking.py @@ -188,14 +188,17 @@ def main(): # Validate required arguments if not args.documents and not args.ingest: - parser.error("Either --documents (for raw docs) or --ingest (for pre-chunked passages) is required") + parser.error( + "Either --documents (for raw docs) or --ingest (for pre-chunked passages) is required") # Set default database name if not provided if args.database is None: args.database = VectorDB.get_default_db_name() - db_file_path = args.database if args.database.endswith('.db') else f"{args.database}.db" - db_base_name = args.database.replace('.db', '') if args.database.endswith('.db') else args.database + db_file_path = args.database if args.database.endswith( + '.db') else f"{args.database}.db" + db_base_name = args.database.replace( + '.db', '') if args.database.endswith('.db') else args.database # Check if database already exists db_exists = Path(db_file_path).exists() @@ -249,8 +252,10 @@ def main(): if args.save_passages: passages_file = args.save_passages else: - # Generate filename based on source directory and chunking parameters - source_dir_name = os.path.basename(os.path.normpath(args.documents)) + # Generate filename based on source directory and chunking + # parameters + source_dir_name = os.path.basename( + os.path.normpath(args.documents)) passages_file = f"passages_{source_dir_name}_len{args.chunk_size}_ov{args.chunk_overlap}_{args.text_boundary}.json" print(f" Auto-generated passages filename: {passages_file}") @@ -358,7 +363,8 @@ def main(): # STEP 5: VALIDATE DATABASE (after save, not part of perf) # ============================================================ print("[5/5] Validating database...") - validation_results = validate_database(rag_db, expected_passages=num_passages if args.documents else None) + validation_results = validate_database( + rag_db, expected_passages=num_passages if args.documents else None) vector_count = validation_results["vector_count"] if validation_results["validation_passed"]: @@ -460,7 +466,7 @@ def main(): import shutil try: shutil.rmtree(temp_text_dir) - except: + except BaseException: pass return 0 diff --git a/e2e-rag/multi_shot_retrieval.py b/e2e-rag/multi_shot_retrieval.py index 3ee84c731d..905b83bfbb 100644 --- a/e2e-rag/multi_shot_retrieval.py +++ b/e2e-rag/multi_shot_retrieval.py @@ -28,6 +28,13 @@ Prompt → Query Rewriter (LLM) → k Sub-queries → Retrieval → Reranking → Evaluation """ +import requests +from llm_logger import LLMLogger +from params import add_all_args +from utils import (set_deterministic_seeds, filter_dataset_by_difficulty, + setup_llm_config, get_device_config) +from evaluation import evaluate_retrieval_query, run_evaluation +from retrieve import VectorDB import argparse import json import re @@ -46,13 +53,6 @@ # Get OpenRouter API key from environment OPENROUTER_API_KEY = os.environ.get('OPENROUTER_API_KEY', '') -from retrieve import VectorDB -from evaluation import evaluate_retrieval_query, run_evaluation -from utils import (set_deterministic_seeds, filter_dataset_by_difficulty, - setup_llm_config, get_device_config) -from params import add_all_args -from llm_logger import LLMLogger -import requests # Prompts @@ -171,15 +171,15 @@ **If there are NO NEW documents to evaluate, skip this task and go to TASK 3** -SUMMARY REQUIREMENTS: +SUMMARY REQUIREMENTS: - Extract and preserve specific details - only facts from the document TASK 2: CHECK IF SUFFICIENT AND CONNECT INFORMATION Review ALL KEPT documents and summaries. Actively connect facts across documents: -- Identify entities by matching names across summaries +- Identify entities by matching names across summaries - Chain relationships (A → B → C) -- Cross-reference dates/events +- Cross-reference dates/events - Build complete chains: Person → Family member → Attribute, or Event → Year → Cross-reference If you can construct a complete answer chain with specific names/facts from kept documents, provide final answer. @@ -204,7 +204,7 @@ **For People/Biography:** - Full article: Just the person's name "Harriet Lane" -- Family: "Person X family", "Person X parents" +- Family: "Person X family", "Person X parents" - Specific relative: "Person X" then extract family, don't search "Person X mother" repeatedly **For Events/Dates:** @@ -240,7 +240,7 @@ {{"relevance": [], "summaries": [], "queries": ["Nth position holder name", "specific event list"], "feedback": "Starting with direct entity/list searches."}} CRITICAL REQUIREMENTS: -- If NO NEW documents: return empty arrays: "relevance": [], "summaries": [] +- If NO NEW documents: return empty arrays: "relevance": [], "summaries": [] - If {len_new_docs} NEW documents: return exactly {len_new_docs} relevance scores and {len_new_docs} summaries - For relevant docs (relevance=1): summary MUST extract specific facts (names, dates, relationships, family details from infobox/text) - For irrelevant docs (relevance=0): summary MUST be empty string "" @@ -254,9 +254,6 @@ Respond only in JSON format""" - - - def get_chat_completions_headers(service_url: str): """Headers for an OpenAI-compatible /v1/chat/completions request. @@ -272,21 +269,22 @@ def get_chat_completions_headers(service_url: str): } return {} + def call_chat_completions(service_url: str, model_name: str, messages: List[Dict], - temperature: float = 1.0, max_tokens: int = 4096, - top_p: float = 1.0, - top_k: int = -1, - reasoning_effort: str = "medium", - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - repetition_penalty: float = 1.0, - max_retries: int = 5, - logger: Optional[LLMLogger] = None, - component: str = "unknown", - hop_count: Optional[int] = None, - context: Dict[str, Any] = None, - perf_test_cache: Optional[Any] = None, - query_id: Optional[str] = None) -> str: + temperature: float = 1.0, max_tokens: int = 4096, + top_p: float = 1.0, + top_k: int = -1, + reasoning_effort: str = "medium", + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, + max_retries: int = 5, + logger: Optional[LLMLogger] = None, + component: str = "unknown", + hop_count: Optional[int] = None, + context: Dict[str, Any] = None, + perf_test_cache: Optional[Any] = None, + query_id: Optional[str] = None) -> str: """Call OpenRouter API with proper authentication and logging. Default sampling parameters: @@ -307,7 +305,8 @@ def call_chat_completions(service_url: str, model_name: str, messages: List[Dict if reasoning_effort != "medium": payload["reasoning_effort"] = reasoning_effort else: - payload["reasoning_effort"] = reasoning_effort # Always include for logging + # Always include for logging + payload["reasoning_effort"] = reasoning_effort # Add optional sampling parameters if frequency_penalty != 0.0: @@ -322,22 +321,32 @@ def call_chat_completions(service_url: str, model_name: str, messages: List[Dict # Performance test mode: check if we have cached response cached_response = None if perf_test_cache and query_id and component and hop_count is not None: - cached_response = perf_test_cache.get_response(query_id, component, hop_count) + cached_response = perf_test_cache.get_response( + query_id, component, hop_count) if cached_response: print(f" [PERF TEST MODE] Will attempt real LLM call for performance measurement, but return cached response for deterministic pipeline") - print(f" [PERF TEST MODE] CRITICAL: LLM call MUST succeed - test will STOP if LLM service is unavailable") + print( + f" [PERF TEST MODE] CRITICAL: LLM call MUST succeed - test will STOP if LLM service is unavailable") else: - print(f" [WARNING] No cached response for {component} hop {hop_count}, will use real LLM response") + print( + f" [WARNING] No cached response for {component} hop {hop_count}, will use real LLM response") for attempt in range(max_retries): start_time = time.time() try: - response = requests.post(service_url, json=payload, headers=headers, timeout=120) + response = requests.post( + service_url, + json=payload, + headers=headers, + timeout=120) # Retry on rate limit if response.status_code == 429: - retry_after = int(response.headers.get('Retry-After', 2 ** attempt)) - print(f" Rate limited (429). Retrying in {retry_after}s (attempt {attempt+1}/{max_retries})") + retry_after = int( + response.headers.get( + 'Retry-After', 2 ** attempt)) + print( + f" Rate limited (429). Retrying in {retry_after}s (attempt {attempt+1}/{max_retries})") time.sleep(retry_after) continue @@ -352,12 +361,14 @@ def call_chat_completions(service_url: str, model_name: str, messages: List[Dict if not llm_output: reasoning_content = message.get('reasoning_content') or '' if reasoning_content: - json_match = re.search(r'\{.*\}', reasoning_content, re.DOTALL) + json_match = re.search( + r'\{.*\}', reasoning_content, re.DOTALL) if json_match: llm_output = json_match.group(0) if not llm_output: - print(f" WARNING [{component}]: LLM returned empty content. Raw response: {json.dumps(result)[:500]}") + print( + f" WARNING [{component}]: LLM returned empty content. Raw response: {json.dumps(result)[:500]}") # Log this call if logger: @@ -368,14 +379,19 @@ def call_chat_completions(service_url: str, model_name: str, messages: List[Dict response=result, latency_ms=latency_ms, context=context or {}, - simulated_response=cached_response # None in normal mode, cached value in perf test mode + # None in normal mode, cached value in perf test mode + simulated_response=cached_response ) - # In perf test mode, return cached response instead of real LLM output + # In perf test mode, return cached response instead of real LLM + # output if cached_response: - print(f" [PERF TEST MODE] Returning simulated response (LLM generated: {len(llm_output)} chars, Simulated: {len(cached_response)} chars)") - print(f" [PERF TEST MODE] Real LLM output: {llm_output[:200]}{'...' if len(llm_output) > 200 else ''}") - print(f" [PERF TEST MODE] Cached output (used in pipeline): {cached_response[:200]}{'...' if len(cached_response) > 200 else ''}") + print( + f" [PERF TEST MODE] Returning simulated response (LLM generated: {len(llm_output)} chars, Simulated: {len(cached_response)} chars)") + print( + f" [PERF TEST MODE] Real LLM output: {llm_output[:200]}{'...' if len(llm_output) > 200 else ''}") + print( + f" [PERF TEST MODE] Cached output (used in pipeline): {cached_response[:200]}{'...' if len(cached_response) > 200 else ''}") return cached_response return llm_output @@ -384,54 +400,71 @@ def call_chat_completions(service_url: str, model_name: str, messages: List[Dict status = e.response.status_code if e.response is not None else None if status in (502, 503, 504) and attempt < max_retries - 1: wait = 2 ** attempt - print(f" Server error ({status}). Retrying in {wait}s (attempt {attempt+1}/{max_retries})") + print( + f" Server error ({status}). Retrying in {wait}s (attempt {attempt+1}/{max_retries})") time.sleep(wait) continue # In perf test mode, LLM calls MUST succeed for valid benchmarking if cached_response: print(f" ERROR [{component}]: HTTP {status}: {e}") - print(f" [PERF TEST MODE FATAL] LLM call failed - cannot proceed with cached response") - print(f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") - print(f" [PERF TEST MODE FATAL] Please ensure LLM service is running on {service_url}") - raise RuntimeError(f"Perf test mode requires LLM service to be available. LLM call failed for {component}") from e + print( + f" [PERF TEST MODE FATAL] LLM call failed - cannot proceed with cached response") + print( + f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") + print( + f" [PERF TEST MODE FATAL] Please ensure LLM service is running on {service_url}") + raise RuntimeError( + f"Perf test mode requires LLM service to be available. LLM call failed for {component}") from e print(f" ERROR [{component}]: HTTP {status}: {e}") raise except requests.exceptions.Timeout: if attempt < max_retries - 1: wait = 2 ** attempt - print(f" Timeout. Retrying in {wait}s (attempt {attempt+1}/{max_retries})") + print( + f" Timeout. Retrying in {wait}s (attempt {attempt+1}/{max_retries})") time.sleep(wait) continue # In perf test mode, LLM calls MUST succeed for valid benchmarking if cached_response: - print(f" ERROR [{component}]: Request timed out after {max_retries} attempts") - print(f" [PERF TEST MODE FATAL] LLM call timed out - cannot proceed with cached response") - print(f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") - raise RuntimeError(f"Perf test mode requires LLM service to respond. LLM call timed out for {component}") - - print(f" ERROR [{component}]: Request timed out after {max_retries} attempts") + print( + f" ERROR [{component}]: Request timed out after {max_retries} attempts") + print( + f" [PERF TEST MODE FATAL] LLM call timed out - cannot proceed with cached response") + print( + f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") + raise RuntimeError( + f"Perf test mode requires LLM service to respond. LLM call timed out for {component}") + + print( + f" ERROR [{component}]: Request timed out after {max_retries} attempts") raise except Exception as e: # In perf test mode, LLM calls MUST succeed for valid benchmarking if cached_response: print(f" ERROR [{component}]: {e}") - print(f" [PERF TEST MODE FATAL] LLM call failed with exception - cannot proceed with cached response") - print(f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") - print(f" [PERF TEST MODE FATAL] Please ensure LLM service is running and accessible") - raise RuntimeError(f"Perf test mode requires LLM service to be available. LLM call failed for {component}") from e + print( + f" [PERF TEST MODE FATAL] LLM call failed with exception - cannot proceed with cached response") + print( + f" [PERF TEST MODE FATAL] Performance benchmarking requires all LLM calls to succeed for run-to-run equivalency") + print( + f" [PERF TEST MODE FATAL] Please ensure LLM service is running and accessible") + raise RuntimeError( + f"Perf test mode requires LLM service to be available. LLM call failed for {component}") from e print(f" ERROR [{component}]: {e}") raise print(f" ERROR [{component}]: Max retries ({max_retries}) exceeded") if cached_response: - print(f" [PERF TEST MODE] Real LLM call failed, returning cached response") + print( + f" [PERF TEST MODE] Real LLM call failed, returning cached response") return cached_response raise RuntimeError(f"LLM call failed after {max_retries} retries") + def evaluate_document_relevance(question: str, new_documents: List[tuple], kept_documents: List[tuple], @@ -512,15 +545,18 @@ def evaluate_document_relevance(question: str, relevance = relevance_result.get("relevance", []) if len(relevance) != len(new_documents): - print(f" Warning: Relevance mismatch. Expected {len(new_documents)}, got {len(relevance)}") + print( + f" Warning: Relevance mismatch. Expected {len(new_documents)}, got {len(relevance)}") return {"relevance": [1] * len(new_documents)} return {"relevance": relevance} except Exception as e: # In perf test mode, propagate fatal errors (don't fallback) - perf_test_cache = llm_config.get('perf_test_cache') if llm_config else None - if perf_test_cache and isinstance(e, RuntimeError) and "Perf test mode requires" in str(e): + perf_test_cache = llm_config.get( + 'perf_test_cache') if llm_config else None + if perf_test_cache and isinstance( + e, RuntimeError) and "Perf test mode requires" in str(e): # This is a perf test mode fatal error - must propagate it raise @@ -529,13 +565,13 @@ def evaluate_document_relevance(question: str, def check_sufficiency(question: str, - kept_documents: List[tuple], - iteration: int, - max_iterations: int, - llm_config: Optional[Dict[str, Any]] = None, - logger: Optional[LLMLogger] = None, - hop_count: int = 1, - query_id: Optional[str] = None) -> Dict[str, Any]: + kept_documents: List[tuple], + iteration: int, + max_iterations: int, + llm_config: Optional[Dict[str, Any]] = None, + logger: Optional[LLMLogger] = None, + hop_count: int = 1, + query_id: Optional[str] = None) -> Dict[str, Any]: """ Check if kept documents are sufficient to answer the question. Uses gpt-oss-120b model via OpenRouter. @@ -601,13 +637,15 @@ def check_sufficiency(question: str, query_id=query_id ) - print(f" [DEBUG] Sufficiency check raw output: {llm_output[:200]}...") + print( + f" [DEBUG] Sufficiency check raw output: {llm_output[:200]}...") if not llm_output: print(f" Warning: Sufficiency check returned empty") # On final iteration, force sufficient if iteration >= max_iterations: - return {"sufficient": True, "reasoning": "Max iterations reached"} + return {"sufficient": True, + "reasoning": "Max iterations reached"} return {"sufficient": False, "reasoning": "LLM returned empty"} if llm_output.startswith("```"): @@ -635,24 +673,27 @@ def check_sufficiency(question: str, except Exception as e: # In perf test mode, propagate fatal errors (don't fallback) - perf_test_cache = llm_config.get('perf_test_cache') if llm_config else None - if perf_test_cache and isinstance(e, RuntimeError) and "Perf test mode requires" in str(e): + perf_test_cache = llm_config.get( + 'perf_test_cache') if llm_config else None + if perf_test_cache and isinstance( + e, RuntimeError) and "Perf test mode requires" in str(e): # This is a perf test mode fatal error - must propagate it raise print(f" Error in sufficiency check: {e}") # On final iteration, force sufficient if iteration >= max_iterations: - return {"sufficient": True, "reasoning": f"Max iterations reached (error: {str(e)})"} + return {"sufficient": True, + "reasoning": f"Max iterations reached (error: {str(e)})"} return {"sufficient": False, "reasoning": f"Error: {str(e)}"} def generate_answer(question: str, - kept_documents: List[tuple], - llm_config: Optional[Dict[str, Any]] = None, - logger: Optional[LLMLogger] = None, - hop_count: Optional[int] = None, - query_id: Optional[str] = None) -> str: + kept_documents: List[tuple], + llm_config: Optional[Dict[str, Any]] = None, + logger: Optional[LLMLogger] = None, + hop_count: Optional[int] = None, + query_id: Optional[str] = None) -> str: """ Generate final answer from kept documents using gpt-oss-120b. @@ -726,9 +767,12 @@ def generate_answer(question: str, return llm_output.strip() except Exception as e: - # In perf test mode, propagate fatal errors (don't fallback to "Unknown") - perf_test_cache = llm_config.get('perf_test_cache') if llm_config else None - if perf_test_cache and isinstance(e, RuntimeError) and "Perf test mode requires" in str(e): + # In perf test mode, propagate fatal errors (don't fallback to + # "Unknown") + perf_test_cache = llm_config.get( + 'perf_test_cache') if llm_config else None + if perf_test_cache and isinstance( + e, RuntimeError) and "Perf test mode requires" in str(e): # This is a perf test mode fatal error - must propagate it raise @@ -737,15 +781,15 @@ def generate_answer(question: str, def generate_search_queries(question: str, - kept_documents: List[tuple], - max_queries: int = 3, - query_history: Optional[List[str]] = None, - query_results: Optional[List[int]] = None, - feedback_history: Optional[List[str]] = None, - llm_config: Optional[Dict[str, Any]] = None, - logger: Optional[LLMLogger] = None, - hop_count: int = 1, - query_id: Optional[str] = None) -> Dict[str, Any]: + kept_documents: List[tuple], + max_queries: int = 3, + query_history: Optional[List[str]] = None, + query_results: Optional[List[int]] = None, + feedback_history: Optional[List[str]] = None, + llm_config: Optional[Dict[str, Any]] = None, + logger: Optional[LLMLogger] = None, + hop_count: int = 1, + query_id: Optional[str] = None) -> Dict[str, Any]: """ Generate search queries using gpt-oss-120b via OpenRouter. """ @@ -771,7 +815,8 @@ def generate_search_queries(question: str, history_text = "No queries yet" # Format feedback - feedback_text = "\n".join(feedback_history) if feedback_history else "Iteration 1 - Initial search" + feedback_text = "\n".join( + feedback_history) if feedback_history else "Iteration 1 - Initial search" prompt = QUERY_GENERATION_PROMPT.format( question=question, @@ -833,9 +878,12 @@ def generate_search_queries(question: str, } except Exception as e: - # In perf test mode, propagate fatal errors (don't fallback to original question) - perf_test_cache = llm_config.get('perf_test_cache') if llm_config else None - if perf_test_cache and isinstance(e, RuntimeError) and "Perf test mode requires" in str(e): + # In perf test mode, propagate fatal errors (don't fallback to original + # question) + perf_test_cache = llm_config.get( + 'perf_test_cache') if llm_config else None + if perf_test_cache and isinstance( + e, RuntimeError) and "Perf test mode requires" in str(e): # This is a perf test mode fatal error - must propagate it raise @@ -854,7 +902,7 @@ def query_rewriter(question: str, new_documents: List[tuple], llm_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Evaluates documents AND generates new queries in one LLM call. - + Args: question: The user's original question new_documents: List of NEW document texts to evaluate @@ -864,7 +912,7 @@ def query_rewriter(question: str, new_documents: List[tuple], query_history: List of previous search queries query_results: List of number of documents found for each query (parallel to query_history) previous_feedback: Feedback from previous iteration about what's missing - + Returns: Dict with: - 'relevance' (list of 0/1 for ONLY new_documents) @@ -879,18 +927,18 @@ def query_rewriter(question: str, new_documents: List[tuple], kept_context += f"\n[KEPT {i}] {doc[2]}\n" else: kept_context = "None" - - # Format NEW documents + + # Format NEW documents new_context = "" if new_documents: for i, doc in enumerate(new_documents, 1): new_context += f"\n[NEW {i}] {doc[1]}\n" else: new_context = "None" - + # Combine for context context = f"KEPT DOCUMENTS (already relevant):\n{kept_context}\n\nNEW DOCUMENTS (evaluate these):\n{new_context}" - + # Format query history with results - focus on failures for learning if query_history: failed_queries = [] @@ -901,31 +949,35 @@ def query_rewriter(question: str, new_documents: List[tuple], failed_queries.append(q) else: successful_queries.append(f"{q} ({num_docs} docs)") - + history_parts = [] if failed_queries: - history_parts.append(f"FAILED: {', '.join(failed_queries)}") # Last 3 failures + history_parts.append( + f"FAILED: {', '.join(failed_queries)}") # Last 3 failures if successful_queries: - history_parts.append(f"SUCCESS: {', '.join(successful_queries)}") # Last 2 successes - - history_text = "; ".join(history_parts) if history_parts else "No queries yet" + # Last 2 successes + history_parts.append(f"SUCCESS: {', '.join(successful_queries)}") + + history_text = "; ".join( + history_parts) if history_parts else "No queries yet" else: history_text = "No queries yet" - + # Build feedback history - show progression of what was tried and learned if feedback_history and len(feedback_history) > 0: unique_feedback = [] for fb in reversed(feedback_history): if fb and fb not in unique_feedback: unique_feedback.append(fb) - + if unique_feedback: - feedback_text = "PREVIOUS ATTEMPTS: " + " → ".join(reversed(unique_feedback)) + feedback_text = "PREVIOUS ATTEMPTS: " + \ + " → ".join(reversed(unique_feedback)) else: feedback_text = f"Iteration {len(query_history) + 1 if query_history else 1}" else: feedback_text = f"Iteration {len(query_history) + 1 if query_history else 1} - Initial search" - + print(f"Context: {context}") print(f"History: {history_text}") print(f"Feedback: {feedback_text}") @@ -938,12 +990,12 @@ def query_rewriter(question: str, new_documents: List[tuple], k=max_queries, len_new_docs=len(new_documents) ) - - system_message = f"""You are an expert at multi-hop reasoning and strategic search. - CRITICAL: Never repeat failed queries. - Always try completely different approaches when queries return 0 docs. + + system_message = f"""You are an expert at multi-hop reasoning and strategic search. + CRITICAL: Never repeat failed queries. + Always try completely different approaches when queries return 0 docs. Focus on atomic facts and progressive strategies.""" - + # Use LLM config if provided, otherwise use defaults if llm_config: model_name = llm_config["model_name"] @@ -973,10 +1025,10 @@ def query_rewriter(question: str, new_documents: List[tuple], response = requests.post(service_url, json=payload, timeout=300) response.raise_for_status() result = response.json() - + message = result['choices'][0]['message'] llm_output = message.get('content') - + # Fallback: use reasoning_content if content is empty (thinking models) reasoning_content = message.get('reasoning_content', '') if reasoning_content and not llm_output: @@ -985,12 +1037,15 @@ def query_rewriter(question: str, new_documents: List[tuple], json_match = re.search(r'\{.*\}', reasoning_content, re.DOTALL) if json_match: llm_output = json_match.group(0) - print(f" DEBUG: Extracted JSON from reasoning_content ({len(llm_output)} chars)") + print( + f" DEBUG: Extracted JSON from reasoning_content ({len(llm_output)} chars)") else: - print(f" DEBUG: No JSON found in reasoning_content snippet: {reasoning_content[:200]}") + print( + f" DEBUG: No JSON found in reasoning_content snippet: {reasoning_content[:200]}") if llm_output is None or not llm_output.strip(): - print(f" Warning: LLM returned empty content, using original query as fallback") + print( + f" Warning: LLM returned empty content, using original query as fallback") # Always fall back to original query - never return empty queries return { "relevance": [0] * len(new_documents), @@ -999,25 +1054,25 @@ def query_rewriter(question: str, new_documents: List[tuple], "feedback": "LLM returned empty response", "answer": "" } - + llm_output = llm_output.strip() - + # Parse JSON output - handle markdown code blocks if llm_output.startswith("```"): llm_output = llm_output.split("```")[1] if llm_output.startswith("json"): llm_output = llm_output[4:] llm_output = llm_output.strip() - + result_data = json.loads(llm_output) - - # Validate format + + # Validate format required_fields = ["relevance"] for field in required_fields: if field not in result_data: print(f"Warning: Missing required field '{field}' in response") result_data[field] = [0] * len(new_documents) - + # Ensure we have either "answer" OR "queries"+"feedback" if "answer" not in result_data: result_data["answer"] = "" @@ -1027,37 +1082,47 @@ def query_rewriter(question: str, new_documents: List[tuple], result_data["feedback"] = "" if "summaries" not in result_data: result_data["summaries"] = [""] * len(new_documents) - - # Ensure relevance array matches NEW document count - fix mismatches by padding/truncating + + # Ensure relevance array matches NEW document count - fix mismatches by + # padding/truncating if len(result_data["relevance"]) != len(new_documents): - print(f"Warning: Relevance array length mismatch. Expected {len(new_documents)}, got {len(result_data['relevance'])}. Auto-fixing.") - relevance = result_data["relevance"][:len(new_documents)] # Truncate if too long - while len(relevance) < len(new_documents): # Pad with 0s if too short + print( + f"Warning: Relevance array length mismatch. Expected {len(new_documents)}, got {len(result_data['relevance'])}. Auto-fixing.") + relevance = result_data["relevance"][:len( + new_documents)] # Truncate if too long + while len(relevance) < len( + new_documents): # Pad with 0s if too short relevance.append(0) result_data["relevance"] = relevance print(f"Fixed relevance array: {relevance}") - - # Ensure summaries array matches NEW document count - fix mismatches by padding/truncating + + # Ensure summaries array matches NEW document count - fix mismatches by + # padding/truncating if len(result_data["summaries"]) != len(new_documents): - print(f"Warning: Summaries array length mismatch. Expected {len(new_documents)}, got {len(result_data['summaries'])}. Auto-fixing.") - summaries = result_data["summaries"][:len(new_documents)] # Truncate if too long - while len(summaries) < len(new_documents): # Pad with empty strings if too short + print( + f"Warning: Summaries array length mismatch. Expected {len(new_documents)}, got {len(result_data['summaries'])}. Auto-fixing.") + summaries = result_data["summaries"][:len( + new_documents)] # Truncate if too long + while len(summaries) < len( + new_documents): # Pad with empty strings if too short summaries.append("") result_data["summaries"] = summaries print(f"Fixed summaries array length: {len(summaries)}") - + # Validate that relevant documents have non-empty summaries - for i, (rel, summary) in enumerate(zip(result_data["relevance"], result_data["summaries"])): + for i, (rel, summary) in enumerate( + zip(result_data["relevance"], result_data["summaries"])): if rel == 1 and not summary.strip(): - print(f"Warning: Document {i+1} marked relevant but has empty summary. This defeats the summarization purpose.") + print( + f"Warning: Document {i+1} marked relevant but has empty summary. This defeats the summarization purpose.") # Don't auto-fix here - let it be empty to debug the issue - + # Ensure queries is a list if not isinstance(result_data["queries"], list): result_data["queries"] = [] - + return result_data - + except requests.exceptions.RequestException as e: print(f"Error calling combined LLM: {e}") return { @@ -1106,14 +1171,14 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], **strategy_params) -> Dict[str, Any]: """ Multi-shot retrieval with iterative query refinement and document evaluation. - + Algorithm: 1. Generate initial search queries based on the original question 2. Retrieve documents for each query 3. Evaluate documents and check if sufficient to answer 4. If not sufficient: generate new queries based on what's missing, go to step 2 5. Repeat until sufficient or max_iterations reached - + Args: rag_db: RAG database instance original_query: Original user question @@ -1127,29 +1192,31 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], verbose: Print detailed information reasoning_effort: LLM reasoning level **strategy_params: Additional parameters for retrieval strategy - + Returns: Dictionary containing evaluation metrics and iteration statistics """ - + start_time = time.perf_counter() llm_start_time = None # set just before first generate_search_queries call llm_end_time = None # set just after generate_answer returns - + # Track iteration history query_history = [] query_results = [] # Track how many docs each query found - kept_docs = [] # List of (url, content, summary) tuples that were marked relevant - new_docs = [] # List of (url, content) tuples just retrieved this iteration + # List of (url, content, summary) tuples that were marked relevant + kept_docs = [] + # List of (url, content) tuples just retrieved this iteration + new_docs = [] all_retrieved_urls = set() iteration_times = [] previous_feedback = "" # Feedback from previous iteration feedback_history = [] # Track all feedback to show progression - + sufficient = False iteration = 0 final_answer = "" - + if verbose: print(f"\n{'='*80}") print(f"MULTI-SHOT RETRIEVAL") @@ -1158,29 +1225,33 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], print(f"Max iterations: {max_iterations}") print(f"Max sub-queries per iteration: {max_sub_queries}") print(f"{'='*80}\n") - + while not sufficient and iteration < max_iterations: iteration += 1 iteration_start = time.perf_counter() - + if verbose: print(f"\n{'─'*80}") print(f"ITERATION {iteration}/{max_iterations}") print(f"{'─'*80}") - - # Step 1: Use combined function to grade NEW docs AND generate new queries + + # Step 1: Use combined function to grade NEW docs AND generate new + # queries if verbose: print(f"\n Evaluating documents and generating queries...") - - # Aggressive summarization: use summaries after iteration 2 to improve information connection + + # Aggressive summarization: use summaries after iteration 2 to improve + # information connection total_content_length = sum(len(doc[1]) for doc in kept_docs) # Special handling for iteration 1: decompose original query first if iteration == 1 and not new_docs and not kept_docs: if verbose: - print(f" [ITERATION 1] Decomposing original query into sub-queries via generate_search_queries...") + print( + f" [ITERATION 1] Decomposing original query into sub-queries via generate_search_queries...") - # Use generate_search_queries for initial decomposition (uses query_model_name / gpt-oss-120b) + # Use generate_search_queries for initial decomposition (uses + # query_model_name / gpt-oss-120b) if llm_start_time is None: llm_start_time = time.perf_counter() query_result = generate_search_queries( @@ -1207,7 +1278,8 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], sufficient = False final_answer = "" - current_feedback = query_result.get("feedback", "Initial query decomposition") + current_feedback = query_result.get( + "feedback", "Initial query decomposition") relevance = [] summaries = [] reasoning_steps = "" @@ -1217,7 +1289,8 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], relevance = [] if new_docs: if verbose: - print(f" [CALL1] Evaluating {len(new_docs)} new documents with gpt-oss-20b...") + print( + f" [CALL1] Evaluating {len(new_docs)} new documents with gpt-oss-20b...") relevance_result = evaluate_document_relevance( question=original_query, @@ -1228,7 +1301,8 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], hop_count=iteration, query_id=query_id ) - relevance = relevance_result.get("relevance", [1] * len(new_docs)) + relevance = relevance_result.get( + "relevance", [1] * len(new_docs)) # Add relevant docs to kept_docs IMMEDIATELY for i, (url, content) in enumerate(new_docs): @@ -1236,14 +1310,16 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], kept_docs.append((url, content, content[:1000])) if verbose: - print(f" Marked {sum(relevance)} of {len(new_docs)} docs as relevant") + print( + f" Marked {sum(relevance)} of {len(new_docs)} docs as relevant") print(f" Relevance array: {relevance}") print(f" Total kept docs now: {len(kept_docs)}") # CALL 2: Check sufficiency - uses gpt-oss-120b if kept_docs: if verbose: - print(f" [CALL2] Checking sufficiency with gpt-oss-120b (iteration {iteration}/{max_iterations})...") + print( + f" [CALL2] Checking sufficiency with gpt-oss-120b (iteration {iteration}/{max_iterations})...") sufficiency_result = check_sufficiency( question=original_query, @@ -1267,11 +1343,13 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], sufficient = False sufficiency_reasoning = "No relevant documents kept yet" - # CALL 3a or 3b: Either generate answer (if sufficient) or generate queries (if not) + # CALL 3a or 3b: Either generate answer (if sufficient) or generate + # queries (if not) if sufficient: # CALL 3a: Generate final answer - uses gpt-oss-120b if verbose: - print(f" [CALL3a] Generating final answer with gpt-oss-120b...") + print( + f" [CALL3a] Generating final answer with gpt-oss-120b...") final_answer = generate_answer( question=original_query, @@ -1291,11 +1369,13 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], else: # CALL 3b: Generate search queries - uses gpt-oss-120b if verbose: - print(f" [CALL3b] Generating search queries with gpt-oss-120b...") + print( + f" [CALL3b] Generating search queries with gpt-oss-120b...") # Cap kept_docs sent to avoid context overflow MAX_DOCS_FOR_QUERY_GEN = 12 - docs_for_query_gen = kept_docs[-MAX_DOCS_FOR_QUERY_GEN:] if len(kept_docs) > MAX_DOCS_FOR_QUERY_GEN else kept_docs + docs_for_query_gen = kept_docs[-MAX_DOCS_FOR_QUERY_GEN:] if len( + kept_docs) > MAX_DOCS_FOR_QUERY_GEN else kept_docs query_result = generate_search_queries( question=original_query, @@ -1320,27 +1400,30 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], summaries = [] reasoning_steps = "" - # Add to feedback history if it's new and meaningful - if current_feedback and current_feedback.strip() and current_feedback != previous_feedback: + if current_feedback and current_feedback.strip( + ) and current_feedback != previous_feedback: feedback_history.append(current_feedback.strip()) previous_feedback = current_feedback # Only print status for iteration 1 after we've printed the queries # For iteration 2+, print status after CALL1/CALL2/CALL3 - # Skip printing here for iteration 1 (will print later after retrieval/grading) + # Skip printing here for iteration 1 (will print later after + # retrieval/grading) if verbose and iteration > 1: print(f" Sufficient: {'yes' if sufficient else 'no'}") print(f" Kept docs: {len(kept_docs)}") if new_docs: print(f" New docs evaluated: {len(new_docs)}") - print(f" Relevant new docs: {sum(relevance)}/{len(relevance)}") + print( + f" Relevant new docs: {sum(relevance)}/{len(relevance)}") print(f" Relevance array: {relevance}") # Show summary quality if summaries: non_empty_summaries = [s for s in summaries if s.strip()] - print(f" Generated summaries: {len(non_empty_summaries)}/{len(summaries)} non-empty") + print( + f" Generated summaries: {len(non_empty_summaries)}/{len(summaries)} non-empty") for i, summary in enumerate(summaries): if summary.strip() and relevance[i] == 1: print(f" Summary {i+1}: {summary}...") @@ -1349,10 +1432,11 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], if not sufficient: print(f" Feedback: {previous_feedback}") print(f" Generated {len(sub_queries)} new queries") - - # Clear new_docs for next iteration (already added to kept_docs in CALL1 block above) + + # Clear new_docs for next iteration (already added to kept_docs in + # CALL1 block above) new_docs = [] - + # If sufficient, we're done if sufficient: if verbose: @@ -1361,21 +1445,22 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], print(f" Answer: {final_answer[:200]}...") iteration_times.append(time.perf_counter() - iteration_start) break - - # If no queries generated, fall back to original query rather than stopping + + # If no queries generated, fall back to original query rather than + # stopping if not sub_queries: if verbose: print(f"\n ⚠ No new queries generated, falling back to original query") sub_queries = [original_query] - + if verbose: print(f"\n New queries:") for i, q in enumerate(sub_queries, 1): print(f" {i}. {q}") - + # Step 2: Retrieve for each sub-query and track results num_sub_queries = len(sub_queries) - #docs_per_subquery = max(1, top_k_retriever // num_sub_queries) + # docs_per_subquery = max(1, top_k_retriever // num_sub_queries) docs_per_subquery = max(1, top_k_retriever) iteration_results = [] @@ -1383,45 +1468,52 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], # Calculate target docs per subquery after reranking target_docs_per_subquery = max(3, top_k_retriever // num_sub_queries) - + for i, sub_query in enumerate(sub_queries, 1): if verbose: print(f"\n Retrieving for query {i}: {sub_query[:60]}...") - + query_start_count = len(new_docs) # Track docs before this query - + # Retrieve if retrieval_strategy == "fixed_k": results = rag_db.lookup(sub_query, k=docs_per_subquery) else: from retrieve.filter import filter original_max_results = strategy_params.get("max_results", 20) - #adjusted_max_results = max(1, original_max_results // num_sub_queries) + # adjusted_max_results = max(1, original_max_results // num_sub_queries) adjusted_max_results = max(1, original_max_results) strategy_params_copy = strategy_params.copy() strategy_params_copy["max_results"] = adjusted_max_results - results = filter(rag_db, sub_query, method=retrieval_strategy, **strategy_params_copy) - + results = filter( + rag_db, + sub_query, + method=retrieval_strategy, + **strategy_params_copy) + # Apply per-subquery reranking if enabled if not no_rerank and len(results) > target_docs_per_subquery: if verbose: - print(f" Reranking {len(results)} docs for this subquery to top {target_docs_per_subquery}...") - + print( + f" Reranking {len(results)} docs for this subquery to top {target_docs_per_subquery}...") + # Extract contents for reranking contents = [r.page_content for r in results] scored_passages = rag_db.rerank(sub_query, contents) - + # Reorder results by reranking scores and take top-k - reranked_indices = [i for i, _ in sorted(enumerate(scored_passages), + reranked_indices = [i for i, _ in sorted(enumerate(scored_passages), key=lambda x: x[1][1], reverse=True)] - results = [results[idx] for idx in reranked_indices[:target_docs_per_subquery]] - + results = [results[idx] + for idx in reranked_indices[:target_docs_per_subquery]] + if verbose: - print(f" After reranking: keeping top {len(results)} docs") + print( + f" After reranking: keeping top {len(results)} docs") elif len(results) > target_docs_per_subquery: # No reranking, just limit to target results = results[:target_docs_per_subquery] - + # Add to new_docs for evaluation (avoid duplicates) for result in results: if 'original_url' in result.metadata and result.metadata['original_url']: @@ -1430,27 +1522,30 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], all_retrieved_urls.add(url) new_docs.append((url, result.page_content)) iteration_results.append(result) - + # Track how many NEW docs this query found docs_found_by_query = len(new_docs) - query_start_count per_query_counts.append(docs_found_by_query) - + if verbose: - print(f" Retrieved {len(results)} docs, {docs_found_by_query} new unique docs from this query") + print( + f" Retrieved {len(results)} docs, {docs_found_by_query} new unique docs from this query") for j, result in enumerate(results, 1): url = result.metadata.get('original_url', 'N/A') passage = result.page_content[:300].replace('\n', ' ') print(f" [{j}] {url}\n {passage}...") - + # Add queries and their results to history for sub_query, count in zip(sub_queries, per_query_counts): query_history.append(sub_query) query_results.append(count) - + if verbose: - print(f" Total kept docs: {len(kept_docs)}, new docs to evaluate: {len(new_docs)}") + print( + f" Total kept docs: {len(kept_docs)}, new docs to evaluate: {len(new_docs)}") - # For iteration 1, print status summary now (after retrieval, before next iteration's grading) + # For iteration 1, print status summary now (after retrieval, before + # next iteration's grading) if verbose and iteration == 1: print(f"\n Iteration 1 Summary:") print(f" Generated {len(sub_queries)} new queries") @@ -1458,15 +1553,15 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], iteration_time = time.perf_counter() - iteration_start iteration_times.append(iteration_time) - + if iteration >= max_iterations: if verbose: print(f"\n ⚠ Maximum iterations reached") break - + # Final processing total_time = time.perf_counter() - start_time - + # Extract URLs from kept_docs retrieved_urls = [] for doc in kept_docs: @@ -1474,17 +1569,20 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], retrieved_urls.append(doc[0]) # url is first element elif len(doc) == 2: # Handle old format for backward compatibility retrieved_urls.append(doc[0]) # url is first element - + # Limit to top_k_reranking (reranking already done per-subquery) retrieved_urls = retrieved_urls[:top_k_reranking] - + # Calculate metrics from evaluation import calculate_retrieval_metrics expected_set = set(url for url in expected_urls if url and url.strip()) metrics = calculate_retrieval_metrics(list(expected_set), retrieved_urls) - + # Add iteration statistics - query_llm_time = (llm_end_time - llm_start_time) if (llm_start_time is not None and llm_end_time is not None) else total_time + query_llm_time = ( + llm_end_time - + llm_start_time) if ( + llm_start_time is not None and llm_end_time is not None) else total_time metrics.update({ 'total_time': total_time, 'query_llm_time': query_llm_time, @@ -1495,7 +1593,7 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], 'avg_iteration_time': sum(iteration_times) / len(iteration_times) if iteration_times else 0, 'llm_answer': final_answer, }) - + # Print final results if verbose: print(f"\n{'='*80}") @@ -1509,8 +1607,10 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], print(f"LLM Answer: {final_answer}") if expected_answer: print(f"Expected Answer: {expected_answer}") - print(f"Expected ({len(expected_set)}): {sorted(list(expected_set)[:3])}{'...' if len(expected_set) > 3 else ''}") - print(f"Retrieved ({len(retrieved_urls)} unique docs): {retrieved_urls[:3]}{'...' if len(retrieved_urls) > 3 else ''}") + print( + f"Expected ({len(expected_set)}): {sorted(list(expected_set)[:3])}{'...' if len(expected_set) > 3 else ''}") + print( + f"Retrieved ({len(retrieved_urls)} unique docs): {retrieved_urls[:3]}{'...' if len(retrieved_urls) > 3 else ''}") matches = len(expected_set.intersection(set(retrieved_urls))) print(f"Matches: {matches}") print(f"\nMetrics:") @@ -1519,10 +1619,11 @@ def multi_shot_retrieval(rag_db, original_query: str, expected_urls: List[str], print(f" F1@N: {metrics.get('f1@N', 0.0):.3f}") print(f" MAP: {metrics.get('average_precision', 0.0):.3f}") print(f"\nTiming:") - print(f" Avg per iteration: {metrics['avg_iteration_time']*1000:.1f}ms") + print( + f" Avg per iteration: {metrics['avg_iteration_time']*1000:.1f}ms") print(f" Total: {total_time*1000:.1f}ms") print(f"{'='*80}\n") - + return metrics @@ -1543,7 +1644,7 @@ def run_multi_shot_evaluation(rag_db, dataset_path: str, **strategy_params) -> Dict[str, float]: """ Run multi-shot evaluation on a dataset. - + Args: rag_db: RAG database instance dataset_path: Path to dataset TSV file @@ -1558,21 +1659,21 @@ def run_multi_shot_evaluation(rag_db, dataset_path: str, difficulty: Minimum number of answer links required (0 = no filtering) max_iterations: Maximum iterations for iterative retrieval (default: 10) **strategy_params: Additional parameters for retrieval strategy - + Returns: Dictionary of averaged metrics """ - + df = pd.read_csv(dataset_path, sep='\t') - + # Filter by difficulty if specified df = filter_dataset_by_difficulty(df, difficulty) - + if isinstance(max_queries, int) and max_queries > 0: df = df.head(max_queries) else: max_queries = len(df) - + print(f"\n{'='*80}") print(f"MULTI-SHOT EVALUATION") print(f"{'='*80}") @@ -1586,7 +1687,7 @@ def run_multi_shot_evaluation(rag_db, dataset_path: str, if difficulty > 0: print(f"Difficulty filter: >= {difficulty} answer links") print(f"{'='*80}\n") - + eval_wall_start = time.perf_counter() total_metrics = {} valid_queries = 0 @@ -1600,9 +1701,12 @@ def run_multi_shot_evaluation(rag_db, dataset_path: str, for col in df.columns: if col.startswith('wikipedia_link_') and pd.notna(row[col]): expected_urls.append(row[col].strip()) - expected_answer = row.get('Answer', '').strip() if 'Answer' in row and pd.notna(row.get('Answer')) else "" + expected_answer = row.get( + 'Answer', '').strip() if 'Answer' in row and pd.notna( + row.get('Answer')) else "" if expected_urls: - work_items.append((idx, row['Prompt'], expected_urls, expected_answer)) + work_items.append( + (idx, row['Prompt'], expected_urls, expected_answer)) def process_single_query(item): idx, prompt, expected_urls, expected_answer = item @@ -1630,7 +1734,8 @@ def process_single_query(item): ) query_wall_time = metrics.get('total_time', 0.0) - print(f" [QUERY TIMING] Query {idx+1}/{max_queries}: {query_wall_time:.2f}s") + print( + f" [QUERY TIMING] Query {idx+1}/{max_queries}: {query_wall_time:.2f}s") if logger: logger.end_query( @@ -1672,7 +1777,10 @@ def process_single_query(item): # Parallel execution with thread pool print(f"\n Using {num_workers} parallel workers") with ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = {executor.submit(process_single_query, item): item for item in work_items} + futures = { + executor.submit( + process_single_query, + item): item for item in work_items} for future in as_completed(futures): try: idx, metrics, result = future.result() @@ -1689,45 +1797,60 @@ def process_single_query(item): print(f" Error processing query: {e}") import traceback traceback.print_exc() - + if valid_queries > 0: # Calculate averages - avg_metrics = {name: total / valid_queries for name, total in total_metrics.items()} - + avg_metrics = { + name: total / + valid_queries for name, + total in total_metrics.items()} + # Print summary print(f"\n{'='*80}") print(f"MULTI-SHOT EVALUATION SUMMARY ({valid_queries} queries)") print(f"{'='*80}") print(f"\nPRECISION METRICS:") - print(f" Precision@N: {avg_metrics.get('precision@N', 0.0):.3f}") + print( + f" Precision@N: {avg_metrics.get('precision@N', 0.0):.3f}") print(f"\nRECALL METRICS:") - print(f" Recall@N: {avg_metrics.get('recall@N', 0.0):.3f}") + print( + f" Recall@N: {avg_metrics.get('recall@N', 0.0):.3f}") print(f"\nF1 METRICS:") - print(f" F1@N: {avg_metrics.get('f1@N', 0.0):.3f}") + print( + f" F1@N: {avg_metrics.get('f1@N', 0.0):.3f}") print(f"\nRANKING METRICS:") - print(f" Mean Average Precision: {avg_metrics.get('average_precision', 0.0):.3f}") + print( + f" Mean Average Precision: {avg_metrics.get('average_precision', 0.0):.3f}") print(f"\nRETRIEVAL STATISTICS:") - print(f" Avg Sub-queries: {avg_metrics.get('num_sub_queries', 0.0):.1f}") - print(f" Avg Passages Retrieved: {avg_metrics.get('retrieved_passages_count', 0.0):.1f}") - print(f" Avg Unique Docs (N): {avg_metrics.get('retrieved_docs_count', 0.0):.1f}") + print( + f" Avg Sub-queries: {avg_metrics.get('num_sub_queries', 0.0):.1f}") + print( + f" Avg Passages Retrieved: {avg_metrics.get('retrieved_passages_count', 0.0):.1f}") + print( + f" Avg Unique Docs (N): {avg_metrics.get('retrieved_docs_count', 0.0):.1f}") print(f"\nTIMING:") - print(f" Avg Decomposition Time: {avg_metrics.get('decomposition_time', 0.0)*1000:.1f}ms") - print(f" Avg Retrieval Time: {avg_metrics.get('retrieval_time', 0.0)*1000:.1f}ms") + print( + f" Avg Decomposition Time: {avg_metrics.get('decomposition_time', 0.0)*1000:.1f}ms") + print( + f" Avg Retrieval Time: {avg_metrics.get('retrieval_time', 0.0)*1000:.1f}ms") if avg_metrics.get('reranking_time', 0.0) > 0: - print(f" Avg Reranking Time: {avg_metrics.get('reranking_time', 0.0)*1000:.1f}ms") - print(f" Avg Total Time: {avg_metrics.get('total_time', 0.0)*1000:.1f}ms") - print(f" Avg Query LLM Time: {avg_metrics.get('query_llm_time', 0.0)*1000:.1f}ms") + print( + f" Avg Reranking Time: {avg_metrics.get('reranking_time', 0.0)*1000:.1f}ms") + print( + f" Avg Total Time: {avg_metrics.get('total_time', 0.0)*1000:.1f}ms") + print( + f" Avg Query LLM Time: {avg_metrics.get('query_llm_time', 0.0)*1000:.1f}ms") eval_wall_time = time.perf_counter() - eval_wall_start qps = valid_queries / eval_wall_time if eval_wall_time > 0 else 0.0 print(f" Total Wall Time: {eval_wall_time:.1f}s") print(f" Throughput: {qps:.3f} queries/sec") print(f"{'='*80}\n") - + # Print detailed analysis if requested if detailed_analysis and all_query_metrics: from evaluation import _print_detailed_analysis _print_detailed_analysis(df, all_query_metrics, valid_queries) - + avg_metrics['_per_query_results'] = all_results return avg_metrics else: @@ -1738,26 +1861,26 @@ def process_single_query(item): if __name__ == "__main__": args = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter, description="Multi-shot retrieval with query decomposition") - + # Add all standard parameters add_all_args(args) - + # Add multi-shot specific parameters args.add_argument('--max-sub-queries', type=int, default=3, - help='Maximum number of sub-queries to generate (default: 3)') + help='Maximum number of sub-queries to generate (default: 3)') args.add_argument('--reasoning', type=str, default='medium', - choices=['low', 'medium', 'high'], - help='LLM reasoning level for query decomposition (default: medium)') + choices=['low', 'medium', 'high'], + help='LLM reasoning level for query decomposition (default: medium)') args.add_argument('--max-iterations', type=int, default=10, - help='Maximum number of retrieval iterations (default: 10)') + help='Maximum number of retrieval iterations (default: 10)') args.add_argument('--num-workers', type=int, default=1, - help='Number of parallel query workers (default: 1, sequential)') + help='Number of parallel query workers (default: 1, sequential)') args.add_argument('--temperature', type=float, default=1.0, - help='LLM sampling temperature (default: 1.0)') + help='LLM sampling temperature (default: 1.0)') args.add_argument('--max-retries', type=int, default=5, - help='Max retries for LLM calls on rate limit/server errors (default: 5)') + help='Max retries for LLM calls on rate limit/server errors (default: 5)') args.add_argument('--output-dir', type=str, default='.', - help='Directory for output files (default: current directory)') + help='Directory for output files (default: current directory)') # Special handling for --eval argument for action in args._actions: @@ -1765,12 +1888,12 @@ def process_single_query(item): action.type = lambda x: int(x) if x.isdigit() else True action.const = True break - + args = args.parse_args() - + # Set deterministic seeds set_deterministic_seeds(args.seed) - + # Setup LLM configuration with auto-detection llm_config = setup_llm_config(args) llm_config['temperature'] = args.temperature @@ -1797,13 +1920,15 @@ def process_single_query(item): # Setup device-specific environment device_config = get_device_config() print(f"Device Config: {device_config}") - + # Initialize database if args.database is None: args.database = VectorDB.get_default_db_name() - db_file_path = args.database if args.database.endswith('.db') else f"{args.database}.db" - db_base_name = args.database.replace('.db', '') if args.database.endswith('.db') else args.database + db_file_path = args.database if args.database.endswith( + '.db') else f"{args.database}.db" + db_base_name = args.database.replace( + '.db', '') if args.database.endswith('.db') else args.database rag_db = VectorDB( retriever_model=args.retriever_model, @@ -1816,14 +1941,15 @@ def process_single_query(item): reranker_device=args.reranker_device, benchmark=args.benchmark ) - + # Load database if os.path.exists(db_file_path): print(f"Loading existing database from {db_file_path}") rag_db.from_serialized(db_file_path) else: - raise ValueError(f"Database not found: {db_file_path}. Please create it first using single_shot_retrieval.py") - + raise ValueError( + f"Database not found: {db_file_path}. Please create it first using single_shot_retrieval.py") + # Build strategy parameters strategy_params = {"max_results": args.max_results} if args.retrieval_strategy == "top_p": @@ -1834,7 +1960,9 @@ def process_single_query(item): # Initialize LLM logger with incremental writing experiment_start_time = datetime.now() os.makedirs(args.output_dir, exist_ok=True) - log_filename = os.path.join(args.output_dir, f"llm_logs_multi_shot_{experiment_start_time.strftime('%Y%m%d_%H%M%S')}.json") + log_filename = os.path.join( + args.output_dir, + f"llm_logs_multi_shot_{experiment_start_time.strftime('%Y%m%d_%H%M%S')}.json") # Determine chunk size from database name chunk_size = 768 # default @@ -1848,7 +1976,8 @@ def process_single_query(item): llm_logger = LLMLogger( output_file=log_filename, experiment_metadata={ - "experiment_name": f"multi_shot_{db_base_name}_n{{queries}}", # Will be updated + # Will be updated + "experiment_name": f"multi_shot_{db_base_name}_n{{queries}}", "timestamp_start": experiment_start_time.isoformat(), "timestamp_end": "in_progress", "retrieval_mode": "multi_shot", @@ -1867,7 +1996,6 @@ def process_single_query(item): ) print(f"LLM logs will be written incrementally to: {log_filename}") - # Setup threading infrastructure if parallel workers requested if args.num_workers > 1: print(f"Enabling parallel execution with {args.num_workers} workers") @@ -1875,7 +2003,9 @@ def process_single_query(item): # Run evaluation or single query if args.eval: - max_queries = args.eval if isinstance(args.eval, int) and not isinstance(args.eval, bool) and args.eval > 0 else None + max_queries = args.eval if isinstance( + args.eval, int) and not isinstance( + args.eval, bool) and args.eval > 0 else None metrics = run_multi_shot_evaluation( rag_db, args.dataset, @@ -1894,7 +2024,7 @@ def process_single_query(item): num_workers=args.num_workers, **strategy_params ) - + # Save results per_query_results = metrics.pop('_per_query_results', []) results_data = { @@ -1904,7 +2034,7 @@ def process_single_query(item): "metrics": metrics, "results": per_query_results, } - + result_path = os.path.join(args.output_dir, "result_multi_shot.json") with open(result_path, "w") as f: json.dump(results_data, f, indent=2) @@ -1981,5 +2111,6 @@ def process_single_query(item): rq = rag_db._reranker_queue if rq is not None: avg_ms = rq.total_latency_ms / rq.total_requests if rq.total_requests else 0 - print(f"Reranker stats: {rq.total_requests} requests, {rq.total_documents} docs, {rq.total_latency_ms:.0f}ms total, {avg_ms:.1f}ms/request avg") + print( + f"Reranker stats: {rq.total_requests} requests, {rq.total_documents} docs, {rq.total_latency_ms:.0f}ms total, {avg_ms:.1f}ms/request avg") rag_db.shutdown_reranker() diff --git a/e2e-rag/oracle_single_shot.py b/e2e-rag/oracle_single_shot.py index 1a39f7fd65..1496daf1c5 100644 --- a/e2e-rag/oracle_single_shot.py +++ b/e2e-rag/oracle_single_shot.py @@ -46,14 +46,16 @@ DEFAULT_CHECKPOINT_FILE = "oracle_checkpoint.pkl" DEFAULT_SERVICE_URL = "http://localhost:8123/v1/chat/completions" -#DEFAULT_MODEL_NAME = "/mnt/weka/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct" -#DEFAULT_MODEL_NAME = "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-405B-Instruct-v2" +# DEFAULT_MODEL_NAME = "/mnt/weka/data/pytorch/llama3.3/Meta-Llama-3.3-70B-Instruct" +# DEFAULT_MODEL_NAME = "/mnt/weka/data/pytorch/llama3.1/Meta-Llama-3.1-405B-Instruct-v2" DEFAULT_MODEL_NAME = "/model/gpt-oss-120b-mxfp4" DEFAULT_BATCH_SIZE = 1 DEFAULT_TIMEOUT = 2400 # For reasoning model, it should be large enough -DEFAULT_MAX_TOKENS = 10*1024 -MAX_TOTAL_CHARS = 400000 # total char budget split across all docs per query (~100K tokens @ 4 chars/token) +DEFAULT_MAX_TOKENS = 10 * 1024 +# total char budget split across all docs per query (~100K tokens @ 4 +# chars/token) +MAX_TOTAL_CHARS = 400000 # Global cache for URL to filename mapping _url_to_file_cache: Optional[Dict[str, Path]] = None @@ -145,27 +147,28 @@ def build_url_to_file_cache(wiki_dir: Path) -> Dict[str, Path]: Uses JSON metadata files to get accurate URL mapping. """ cache = {} - + print(f"Building URL cache from {wiki_dir}...") json_files = list(wiki_dir.glob("*.json")) - + for json_file in json_files: try: with open(json_file, 'r') as f: data = json.load(f) - + # Get URL without fragment url = data.get('url', '') source_url = data.get('source_url', '') - + # Remove fragment from source_url if present if '#' in source_url: source_url = source_url.split('#')[0] - + # Get corresponding .txt file txt_file = json_file.with_suffix('.txt') if txt_file.exists(): - # Map both url and source_url (without fragment) to the file + # Map both url and source_url (without fragment) to the + # file if url: cache[url] = txt_file if source_url and source_url != url: @@ -173,7 +176,7 @@ def build_url_to_file_cache(wiki_dir: Path) -> Dict[str, Path]: except Exception as e: # Skip files with errors continue - + print(f"Cached {len(cache)} URL mappings from {len(json_files)} JSON files") return cache @@ -184,27 +187,28 @@ def find_wiki_article(url: str, wiki_dir: Path) -> Optional[Path]: Uses JSON metadata for accurate matching. """ global _url_to_file_cache - + if not url or "wikipedia.org/wiki/" not in url: return None - + # Build cache on first call if _url_to_file_cache is None: _url_to_file_cache = build_url_to_file_cache(wiki_dir) - + # Remove fragment from URL if present url_no_fragment = url.split('#')[0] - + # Look up in cache if url in _url_to_file_cache: return _url_to_file_cache[url] elif url_no_fragment in _url_to_file_cache: return _url_to_file_cache[url_no_fragment] - + return None -def load_wiki_articles(wiki_urls: List[str], wiki_dir: Path, max_chars: int = MAX_TOTAL_CHARS) -> Tuple[List[str], List[str], List[int]]: +def load_wiki_articles(wiki_urls: List[str], wiki_dir: Path, + max_chars: int = MAX_TOTAL_CHARS) -> Tuple[List[str], List[str], List[int]]: """ Load Wikipedia articles from wiki_articles folder. max_chars is the TOTAL character budget shared across all docs for this query. @@ -225,7 +229,8 @@ def load_wiki_articles(wiki_urls: List[str], wiki_dir: Path, max_chars: int = MA if file_path and file_path.exists(): try: - content = file_path.read_text(encoding="utf-8", errors="ignore") + content = file_path.read_text( + encoding="utf-8", errors="ignore") truncated = content[:per_doc_limit] documents.append(truncated) file_paths.append(str(file_path)) @@ -239,14 +244,15 @@ def load_wiki_articles(wiki_urls: List[str], wiki_dir: Path, max_chars: int = MA documents.append("") file_paths.append("") doc_lengths.append(0) - + return documents, file_paths, doc_lengths -def generate_llm_answer(query: str, documents: List[str], urls: List[str], llm_config: Dict) -> str: +def generate_llm_answer( + query: str, documents: List[str], urls: List[str], llm_config: Dict) -> str: """ Generate LLM answer using the provided documents as context. - + This function is adapted from single_shot_retrieval.py _generate_llm_answer """ context_parts = [] @@ -255,15 +261,16 @@ def generate_llm_answer(query: str, documents: List[str], urls: List[str], llm_c source = url or "Unknown source" snippet = doc.strip() context_parts.append(f"[{idx}] Source: {source}\n{snippet}") - - evidence_block = "\n\n".join(context_parts) if context_parts else "No supporting documents were retrieved." - + + evidence_block = "\n\n".join( + context_parts) if context_parts else "No supporting documents were retrieved." + user_prompt = ( "Answer the question using only the provided evidence." " Respond with a single word or short phrase, or 'Unknown' if the evidence is insufficient.\n\n" f"Question:\n{query}\n\nEvidence:\n{evidence_block}" ) - + payload = { "model": llm_config["model_name"], "messages": [ @@ -283,11 +290,14 @@ def generate_llm_answer(query: str, documents: List[str], urls: List[str], llm_c if llm_config.get("reasoning_effort"): payload["reasoning_effort"] = llm_config["reasoning_effort"] - response = requests.post(llm_config["service_url"], json=payload, timeout=llm_config["timeout"]) + response = requests.post( + llm_config["service_url"], + json=payload, + timeout=llm_config["timeout"]) response.raise_for_status() data = response.json() - #from pprint import pprint - #pprint(data, indent=4) + # from pprint import pprint + # pprint(data, indent=4) return data["choices"][0]["message"]["content"].strip() @@ -315,7 +325,7 @@ def parse_wiki_links(wiki_links_str: str) -> List[str]: try: # Use ast.literal_eval to safely parse the string as a Python literal return ast.literal_eval(wiki_links_str) - except: + except BaseException: return [] @@ -344,7 +354,8 @@ def process_single( } try: - documents, file_paths, doc_lengths = load_wiki_articles(wiki_urls, wiki_dir) + documents, file_paths, doc_lengths = load_wiki_articles( + wiki_urls, wiki_dir) result["wiki_file_paths"] = str(file_paths) result["doc_lengths"] = str(doc_lengths) result["total_doc_length"] = sum(doc_lengths) @@ -354,9 +365,11 @@ def process_single( result["num_missing_docs"] = missing_count if missing_count > 0: - print(f" Query {idx}: {missing_count}/{len(wiki_urls)} documents missing") + print( + f" Query {idx}: {missing_count}/{len(wiki_urls)} documents missing") - llm_answer = generate_llm_answer(query, documents, wiki_urls, llm_config) + llm_answer = generate_llm_answer( + query, documents, wiki_urls, llm_config) result["llm_answer"] = llm_answer result["success"] = True @@ -386,7 +399,14 @@ def process_batch( futures_map = {} with ThreadPoolExecutor(max_workers=len(batch_data)) as executor: for idx, query, ground_truth, wiki_urls in batch_data: - future = executor.submit(process_single, idx, query, ground_truth, wiki_urls, wiki_dir, llm_config) + future = executor.submit( + process_single, + idx, + query, + ground_truth, + wiki_urls, + wiki_dir, + llm_config) futures_map[future] = idx results_map = {} @@ -400,38 +420,40 @@ def process_batch( def main(): args = parse_args() - + # Setup paths dataset_path = Path(args.dataset) wiki_dir = Path(args.wiki_articles_dir) checkpoint_file = Path(args.checkpoint_file) - + # Validate paths if not dataset_path.exists(): raise FileNotFoundError(f"Dataset not found: {dataset_path}") if not wiki_dir.exists(): - raise FileNotFoundError(f"Wiki articles directory not found: {wiki_dir}") + raise FileNotFoundError( + f"Wiki articles directory not found: {wiki_dir}") # Pre-build URL cache once before threads start global _url_to_file_cache _url_to_file_cache = build_url_to_file_cache(wiki_dir) - + # Load dataset print(f"Loading dataset from {dataset_path}...") df = pd.read_csv(dataset_path, sep="\t") - + # Apply max_queries limit if specified if args.max_queries: df = df.head(args.max_queries) - + print(f"Total queries in dataset: {len(df)}") - + # Load checkpoint checkpoint_df = load_checkpoint(checkpoint_file) - processed_indices = set(checkpoint_df["index"].tolist()) if not checkpoint_df.empty else set() - + processed_indices = set( + checkpoint_df["index"].tolist()) if not checkpoint_df.empty else set() + print(f"Already processed: {len(processed_indices)} queries") - + # LLM configuration llm_config = { "service_url": args.service_url, @@ -441,16 +463,16 @@ def main(): "enable_thinking": args.enable_thinking, "reasoning_effort": args.reasoning_effort, } - + # Determine which queries to process if args.retry_failed: # Retry only failed queries if not checkpoint_df.empty: failed_df = checkpoint_df[checkpoint_df["success"] == False] - queries_to_process = [(int(row["index"]), df.iloc[int(row["index"])]["Prompt"], + queries_to_process = [(int(row["index"]), df.iloc[int(row["index"])]["Prompt"], df.iloc[int(row["index"])]["Answer"], parse_wiki_links(df.iloc[int(row["index"])]["wiki_links"])) - for _, row in failed_df.iterrows()] + for _, row in failed_df.iterrows()] print(f"Retrying {len(queries_to_process)} failed queries...") else: queries_to_process = [] @@ -460,85 +482,93 @@ def main(): for idx, row in df.iterrows(): if idx not in processed_indices: wiki_urls = parse_wiki_links(row["wiki_links"]) - queries_to_process.append((idx, row["Prompt"], row["Answer"], wiki_urls)) - + queries_to_process.append( + (idx, row["Prompt"], row["Answer"], wiki_urls)) + print(f"Processing {len(queries_to_process)} new queries...") - + if not queries_to_process: # If no new queries and all done, check for failed ones if not args.retry_failed: - failed_count = len(checkpoint_df[checkpoint_df["success"] == False]) if not checkpoint_df.empty else 0 + failed_count = len( + checkpoint_df[checkpoint_df["success"] == False]) if not checkpoint_df.empty else 0 if failed_count > 0: - print(f"\nAll new queries processed. {failed_count} queries failed.") + print( + f"\nAll new queries processed. {failed_count} queries failed.") print("Run with --retry-failed to retry failed queries.") else: print("\nAll queries successfully processed!") else: print("No failed queries to retry!") return - + # Process in batches batch_size = args.batch_size total_batches = (len(queries_to_process) + batch_size - 1) // batch_size - + print(f"Batch size: {batch_size}") print(f"Total batches: {total_batches}") print(f"Service URL: {llm_config['service_url']}") print(f"Model: {llm_config['model_name']}\n") - + for batch_idx in range(total_batches): start_idx = batch_idx * batch_size end_idx = min(start_idx + batch_size, len(queries_to_process)) batch = queries_to_process[start_idx:end_idx] - + print(f"Processing batch {batch_idx + 1}/{total_batches} " f"(queries {start_idx + 1}-{end_idx})...") - + # Process batch batch_results = process_batch(batch, wiki_dir, llm_config) - + # Convert batch results to DataFrame batch_df = pd.DataFrame(batch_results) - + # Update checkpoint if args.retry_failed: # For retry, update existing results # Remove old entries for these indices indices_to_update = batch_df["index"].tolist() - checkpoint_df = checkpoint_df[~checkpoint_df["index"].isin(indices_to_update)] + checkpoint_df = checkpoint_df[~checkpoint_df["index"].isin( + indices_to_update)] # Append new results - checkpoint_df = pd.concat([checkpoint_df, batch_df], ignore_index=True) + checkpoint_df = pd.concat( + [checkpoint_df, batch_df], ignore_index=True) else: # For new queries, append results - checkpoint_df = pd.concat([checkpoint_df, batch_df], ignore_index=True) - + checkpoint_df = pd.concat( + [checkpoint_df, batch_df], ignore_index=True) + # Sort by index for consistency - checkpoint_df = checkpoint_df.sort_values("index").reset_index(drop=True) - + checkpoint_df = checkpoint_df.sort_values( + "index").reset_index(drop=True) + # Save checkpoint after each batch save_checkpoint(checkpoint_file, checkpoint_df) print(f" Checkpoint saved to {checkpoint_file}") - + # Show batch statistics success_count = batch_df["success"].sum() print(f" Batch success rate: {success_count}/{len(batch_results)}\n") - + # Final statistics - print("\n" + "="*60) + print("\n" + "=" * 60) print("Processing complete!") - print("="*60) - + print("=" * 60) + total_processed = len(checkpoint_df) total_success = checkpoint_df["success"].sum() total_failed = total_processed - total_success - + print(f"Total queries processed: {total_processed}") print(f"Successful: {total_success}") print(f"Failed: {total_failed}") - + if total_failed > 0: - print(f"\nRun with --retry-failed to retry {total_failed} failed queries.") - + print( + f"\nRun with --retry-failed to retry {total_failed} failed queries.") + print(f"\nResults saved to: {checkpoint_file}") diff --git a/e2e-rag/params.py b/e2e-rag/params.py index b59e573937..7333feac5e 100644 --- a/e2e-rag/params.py +++ b/e2e-rag/params.py @@ -22,7 +22,7 @@ Usage: from params import add_all_args, add_common_args, add_retrieval_args - + parser = argparse.ArgumentParser() add_all_args(parser) # Add all parameters # OR @@ -37,9 +37,12 @@ # ============================================================================ # Parameter Definitions # ============================================================================ + + class ParamDef: """Parameter definition with metadata.""" - def __init__(self, + + def __init__(self, name: str, arg_names: List[str], type: type, @@ -76,35 +79,36 @@ def __init__(self, self.category = category self.applies_to = applies_to or ["both"] self.optuna_suggest = optuna_suggest - + def add_to_parser(self, parser: argparse.ArgumentParser): """Add this parameter to an argument parser.""" kwargs = { 'help': self.help, 'default': self.default, } - + if self.action: kwargs['action'] = self.action else: kwargs['type'] = self.type - + if self.choices: kwargs['choices'] = self.choices - + if self.nargs: kwargs['nargs'] = self.nargs - + parser.add_argument(*self.arg_names, **kwargs) - + def suggest_value(self, trial): """Suggest a value for Optuna trial.""" if not self.optuna_suggest: - raise ValueError(f"No optuna_suggest config for parameter {self.name}") - + raise ValueError( + f"No optuna_suggest config for parameter {self.name}") + config = self.optuna_suggest suggest_type = config['type'] - + if suggest_type == 'float': return trial.suggest_float( self.name, @@ -477,7 +481,12 @@ def suggest_value(self, trial): choices=["fixed_k", "top_p", "relative"], category="strategy", applies_to=["both"], - optuna_suggest={'type': 'categorical', 'choices': ["fixed_k", "top_p", "relative"]} + optuna_suggest={ + 'type': 'categorical', + 'choices': [ + "fixed_k", + "top_p", + "relative"]} ), ParamDef( name="top_k_retriever", @@ -557,8 +566,10 @@ def suggest_value(self, trial): } # Method-specific parameters -BM25_METHOD_PARAMS = [p for p in ALL_PARAMS if "bm25" in p.applies_to or "both" in p.applies_to] -VECTOR_METHOD_PARAMS = [p for p in ALL_PARAMS if "vector" in p.applies_to or "both" in p.applies_to] +BM25_METHOD_PARAMS = [ + p for p in ALL_PARAMS if "bm25" in p.applies_to or "both" in p.applies_to] +VECTOR_METHOD_PARAMS = [ + p for p in ALL_PARAMS if "vector" in p.applies_to or "both" in p.applies_to] # Optimizable parameters (those with optuna_suggest defined) OPTIMIZABLE_PARAMS = [p for p in ALL_PARAMS if p.optuna_suggest is not None] @@ -568,24 +579,26 @@ def suggest_value(self, trial): # Helper Functions # ============================================================================ + def add_common_args(parser: argparse.ArgumentParser): """ Add common script parameters (ingest, database, query, etc.) - + Args: parser: ArgumentParser to add arguments to """ for param in COMMON_PARAMS: param.add_to_parser(parser) -def add_retrieval_args(parser: argparse.ArgumentParser, + +def add_retrieval_args(parser: argparse.ArgumentParser, method: Optional[str] = None, categories: Optional[List[str]] = None): """ Add retrieval parameters to an argument parser. Includes: General, BM25, Vector, Strategy, and Reranking parameters. Does NOT include common script parameters (use add_common_args for those). - + Args: parser: ArgumentParser to add arguments to method: Filter by method ('bm25', 'vector', or None for all) @@ -593,24 +606,26 @@ def add_retrieval_args(parser: argparse.ArgumentParser, """ # Get non-common params params_to_add = [p for p in ALL_PARAMS if p.category != "common"] - + # Filter by method if method: - params_to_add = [p for p in params_to_add - if method in p.applies_to or "both" in p.applies_to] - + params_to_add = [p for p in params_to_add + if method in p.applies_to or "both" in p.applies_to] + # Filter by category if categories: params_to_add = [p for p in params_to_add if p.category in categories] - + # Add to parser for param in params_to_add: param.add_to_parser(parser) -def add_all_args(parser: argparse.ArgumentParser, method: Optional[str] = None): + +def add_all_args(parser: argparse.ArgumentParser, + method: Optional[str] = None): """ Add all parameters (common + retrieval) to an argument parser. - + Args: parser: ArgumentParser to add arguments to method: Filter by method ('bm25', 'vector', or None for all) @@ -618,48 +633,51 @@ def add_all_args(parser: argparse.ArgumentParser, method: Optional[str] = None): add_common_args(parser) add_retrieval_args(parser, method=method) + def get_optimizable_params(method: Optional[str] = None) -> List[ParamDef]: """ Get list of parameters that can be optimized with Optuna. - + Args: method: Filter by method ('bm25', 'vector', or None for all) - + Returns: List of ParamDef objects """ params = OPTIMIZABLE_PARAMS - + if method: - params = [p for p in params - if method in p.applies_to or "both" in p.applies_to] - + params = [p for p in params + if method in p.applies_to or "both" in p.applies_to] + return params + def suggest_param(trial, param_name: str) -> Any: """ Suggest a parameter value for Optuna trial. - + Args: trial: Optuna trial object param_name: Parameter name - + Returns: Suggested value """ if param_name not in PARAM_BY_NAME: raise ValueError(f"Unknown parameter: {param_name}") - + param = PARAM_BY_NAME[param_name] return param.suggest_value(trial) + def get_default_params(method: str) -> Dict[str, Any]: """ Get default parameter values for a method. - + Args: method: 'bm25' or 'vector' - + Returns: Dictionary of parameter name -> default value """ @@ -669,34 +687,36 @@ def get_default_params(method: str) -> Dict[str, Any]: params = VECTOR_METHOD_PARAMS else: params = ALL_PARAMS - + return {p.name: p.default for p in params} -def format_params_for_cli(params: Dict[str, Any], skip_defaults: bool = True) -> List[str]: + +def format_params_for_cli( + params: Dict[str, Any], skip_defaults: bool = True) -> List[str]: """ Format parameter dictionary as CLI arguments. - + Args: params: Dictionary of parameter name -> value skip_defaults: If True, skip parameters that match their default values - + Returns: List of CLI argument strings """ args = [] - + for name, value in params.items(): if name not in PARAM_BY_NAME: continue - + param_def = PARAM_BY_NAME[name] - + # Skip if value matches default (when skip_defaults=True) if skip_defaults and value == param_def.default: continue - + cli_arg = param_def.arg_names[0] - + if param_def.action == "store_true": if value: args.append(cli_arg) @@ -705,40 +725,42 @@ def format_params_for_cli(params: Dict[str, Any], skip_defaults: bool = True) -> args.append(cli_arg) elif value is not None: args.extend([cli_arg, str(value)]) - + return args + def print_param_info(method: Optional[str] = None): """Print parameter information grouped by category.""" - + if method == "bm25": params = BM25_METHOD_PARAMS elif method == "vector": params = VECTOR_METHOD_PARAMS else: params = ALL_PARAMS - + # Group by category by_category = {} for param in params: if param.category not in by_category: by_category[param.category] = [] by_category[param.category].append(param) - + # Print for category in ["common", "general", "vector", "strategy", "reranking"]: if category not in by_category: continue - + print(f"\n{category.upper()} Parameters:") print("=" * 60) - + for param in by_category[category]: opt_marker = " [optimizable]" if param.optuna_suggest else "" print(f" {param.name}{opt_marker}") print(f" CLI: {', '.join(param.arg_names)}") print(f" Default: {param.default}") - print(f" Help: {param.help[:80]}..." if len(param.help) > 80 else f" Help: {param.help}") + print(f" Help: {param.help[:80]}..." if len( + param.help) > 80 else f" Help: {param.help}") if param.choices: print(f" Choices: {param.choices}") @@ -746,9 +768,10 @@ def print_param_info(method: Optional[str] = None): # Main (for testing/documentation) # ============================================================================ + if __name__ == "__main__": import sys - + if len(sys.argv) > 1 and sys.argv[1] == "list": method = sys.argv[2] if len(sys.argv) > 2 else None print_param_info(method) diff --git a/e2e-rag/perf_test_cache.py b/e2e-rag/perf_test_cache.py index 09be338e8f..0a29451071 100644 --- a/e2e-rag/perf_test_cache.py +++ b/e2e-rag/perf_test_cache.py @@ -89,7 +89,8 @@ def _build_index(self): print(f" Built cache index with {len(self.cache)} LLM responses") - def get_response(self, query_id: str, component: str, hop_count: int) -> Optional[str]: + def get_response(self, query_id: str, component: str, + hop_count: int) -> Optional[str]: """ Retrieve cached LLM response. @@ -104,7 +105,8 @@ def get_response(self, query_id: str, component: str, hop_count: int) -> Optiona key = (str(query_id), component, hop_count) return self.cache.get(key) - def has_response(self, query_id: str, component: str, hop_count: int) -> bool: + def has_response(self, query_id: str, component: str, + hop_count: int) -> bool: """ Check if cached response exists. diff --git a/e2e-rag/read_docs.py b/e2e-rag/read_docs.py index 12ba80272f..9859e6dd1a 100644 --- a/e2e-rag/read_docs.py +++ b/e2e-rag/read_docs.py @@ -47,12 +47,12 @@ class BaseDocumentExtractor(ABC): """Base class for document text extractors.""" - + @abstractmethod def extract_text(self, file_path: str) -> Optional[str]: """Extract text from a document file.""" pass - + @abstractmethod def get_supported_extensions(self) -> List[str]: """Return list of supported file extensions.""" @@ -61,7 +61,7 @@ def get_supported_extensions(self) -> List[str]: class PDFExtractor(BaseDocumentExtractor): """Extract text from PDF files using PyMuPDF (fitz).""" - + def extract_text(self, file_path: str) -> Optional[str]: """Extract text from a single PDF file.""" try: @@ -74,81 +74,83 @@ def extract_text(self, file_path: str) -> Optional[str]: doc.close() return "\n".join(extracted_text) - + except Exception as e: print(f"Error processing PDF {file_path}: {e}") return None - + def get_supported_extensions(self) -> List[str]: return ['.pdf'] class HTMLExtractor(BaseDocumentExtractor): """Extract text from HTML files using BeautifulSoup with focus on retrieval quality.""" - - def __init__(self, preserve_tables: bool = True, preserve_lists: bool = True, + + def __init__(self, preserve_tables: bool = True, preserve_lists: bool = True, text_boundary: str = "sentence"): """ Initialize HTML extractor with configurable options. - + Args: preserve_tables: Whether to preserve table structure - preserve_lists: Whether to preserve list structure + preserve_lists: Whether to preserve list structure text_boundary: Text boundary optimization - "sentence" (default), "word", or "none" """ if BeautifulSoup is None: - raise ImportError("BeautifulSoup is required for HTML processing. Install with: pip install beautifulsoup4") - + raise ImportError( + "BeautifulSoup is required for HTML processing. Install with: pip install beautifulsoup4") + self.preserve_tables = preserve_tables self.preserve_lists = preserve_lists self.text_boundary = text_boundary - + if text_boundary not in ["sentence", "word", "none"]: - raise ValueError("text_boundary must be 'sentence', 'word', or 'none'") - + raise ValueError( + "text_boundary must be 'sentence', 'word', or 'none'") + def extract_text(self, file_path: str) -> Optional[str]: """Extract text from a single HTML file.""" try: with open(file_path, 'r', encoding='utf-8') as f: html_content = f.read() - + return self.extract_text_from_html(html_content) - + except Exception as e: print(f"Error processing HTML {file_path}: {e}") return None - + def extract_text_from_html(self, html_content: str) -> str: """Extract clean text optimized for retrieval systems.""" # Use lxml parser for speed (fallback to html.parser if not available) try: soup = BeautifulSoup(html_content, 'lxml') - except: + except BaseException: soup = BeautifulSoup(html_content, 'html.parser') - + # Remove noise elements completely for element in soup(['script', 'style', 'nav', 'header', 'footer']): element.decompose() - + # Remove Wikipedia-specific metadata and navigation self._remove_wikipedia_metadata(soup) - + # Extract main content using priority order main_content = self._find_main_content(soup) - + # Get plain text with sentence separation if main_content: text = main_content.get_text(separator=' ', strip=True) else: text = soup.get_text(separator=' ', strip=True) - + # Clean and normalize the text return self._clean_text(text) - + def _extract_from_element(self, element) -> List[str]: """Extract text from an HTML element, preserving structure.""" text_parts = [] - + if isinstance(element, NavigableString): text = str(element).strip() if text: @@ -186,9 +188,9 @@ def _extract_from_element(self, element) -> List[str]: text = element.get_text().strip() if text: text_parts.append(text) - + return text_parts - + def _extract_table_text(self, table) -> str: """Extract text from a table element.""" rows = [] @@ -200,7 +202,7 @@ def _extract_table_text(self, table) -> str: if cells: rows.append(' | '.join(cells)) return '\n'.join(rows) - + def _extract_list_text(self, list_elem) -> str: """Extract text from a list element.""" items = [] @@ -210,75 +212,79 @@ def _extract_list_text(self, list_elem) -> str: prefix = '- ' if list_elem.name == 'ul' else f"{len(items) + 1}. " items.append(f"{prefix}{item_text}") return '\n'.join(items) - + def _clean_text(self, text: str) -> str: """Clean and normalize extracted text with configurable boundary optimization.""" # Basic normalization import unicodedata text = unicodedata.normalize('NFKC', text) - + # Replace various whitespace characters with standard space text = re.sub(r'[\u00A0\u2000-\u200B\u2028\u2029]', ' ', text) - + # Clean up whitespace text = text.strip().replace('\r', '\n') text = re.sub(r' +', ' ', text) # Multiple spaces -> single space - text = re.sub(r'\n+', '\n', text) # Multiple newlines -> single newline - + # Multiple newlines -> single newline + text = re.sub(r'\n+', '\n', text) + # Remove empty lines and extra spacing lines = [line.strip() for line in text.split('\n') if line.strip()] text = '\n'.join(lines) - + # Apply boundary optimization based on setting if self.text_boundary == "sentence": text = self._optimize_sentence_boundaries(text) elif self.text_boundary == "word": text = self._optimize_word_boundaries(text) # "none" - no boundary optimization - + return text - + def _optimize_sentence_boundaries(self, text: str) -> str: """Optimize text for sentence-level splitting and retrieval.""" # Add space after sentence endings if missing text = re.sub(r'([.!?])([A-Z])', r'\1 \2', text) - + # Handle common abbreviations that shouldn't split sentences # (e.g., "Mr.", "Dr.", "etc.", "U.S.") abbrev_pattern = r'\b(Mr|Mrs|Dr|Prof|etc|vs|Inc|Ltd|Corp|U\.S|U\.K|E\.g|I\.e)\.(\s+)([a-z])' text = re.sub(abbrev_pattern, r'\1.\2\3', text, flags=re.IGNORECASE) - + return text - + def _optimize_word_boundaries(self, text: str) -> str: """Optimize text for word-level processing and retrieval.""" # Ensure proper spacing around punctuation for better tokenization text = re.sub(r'([.!?,:;])([A-Za-z])', r'\1 \2', text) - + # Handle hyphenated words - keep them as single tokens text = re.sub(r'(\w+)-\s+(\w+)', r'\1-\2', text) - + # Normalize quotation marks and other punctuation - text = text.replace('"', '"').replace('"', '"') # Smart quotes to regular quotes - text = text.replace(''', "'").replace(''', "'") # Smart apostrophes to regular apostrophes - + text = text.replace( + '"', '"').replace( + '"', '"') # Smart quotes to regular quotes + text = text.replace(''', "'").replace(''', + "'") # Smart apostrophes to regular apostrophes + # Ensure consistent spacing text = re.sub(r'\s+', ' ', text) - + return text - + def _remove_wikipedia_metadata(self, soup): """Remove Wikipedia-specific metadata and navigation elements.""" # Wikipedia-specific noise removal selectors_to_remove = [ # Navigation and interface elements - '#mw-navigation', '.navbox', '.navigation-box', + '#mw-navigation', '.navbox', '.navigation-box', '.ambox', '.tmbox', # Edit links and metadata '.mw-editsection', '.edit-section', '.editlink', # References and citations (keep text but remove citation numbers) 'sup.reference', '.reference', '.citation', - # Disambiguation and hatnotes + # Disambiguation and hatnotes '.hatnote', '.dablink', '.rellink', # Categories and external links boxes '#catlinks', '.catlinks', '.external-links', @@ -287,11 +293,11 @@ def _remove_wikipedia_metadata(self, soup): # Image captions and metadata (keep main text) '.thumbcaption .metadata', '.image-metadata' ] - + for selector in selectors_to_remove: for element in soup.select(selector): element.decompose() - + def _find_main_content(self, soup): """Find the main content area with fallback strategy.""" # Priority order for content detection @@ -304,31 +310,31 @@ def _find_main_content(self, soup): '#content', # Generic content ID 'body' # Last resort ] - + for selector in content_selectors: content = soup.select_one(selector) if content: return content - + # Final fallback return soup - + def get_supported_extensions(self) -> List[str]: return ['.html', '.htm'] class DocumentProcessor: """Unified document processor that handles both PDF and HTML files.""" - - def __init__(self, preserve_tables: bool = True, preserve_lists: bool = True, + + def __init__(self, preserve_tables: bool = True, preserve_lists: bool = True, text_boundary: str = "sentence", benchmark: bool = False, processes: int = 4): """ Initialize document processor. - + Args: preserve_tables: Whether to preserve table structure (HTML only) - preserve_lists: Whether to preserve list structure (HTML only) + preserve_lists: Whether to preserve list structure (HTML only) text_boundary: Text boundary optimization - "sentence" (default), "word", or "none" benchmark: Enable performance monitoring processes: Number of parallel processes for document processing @@ -337,62 +343,64 @@ def __init__(self, preserve_tables: bool = True, preserve_lists: bool = True, self.extractors = { '.pdf': PDFExtractor(), } - + # Only add HTML extractor if BeautifulSoup is available if BeautifulSoup is not None: self.extractors.update({ '.html': HTMLExtractor(preserve_tables, preserve_lists, text_boundary), '.htm': HTMLExtractor(preserve_tables, preserve_lists, text_boundary), }) - + self.url_mapping = {} self.benchmark = benchmark self.monitor = None - + # Store config for worker processes self.preserve_tables = preserve_tables self.preserve_lists = preserve_lists self.text_boundary = text_boundary - + # Initialize monitoring if benchmark mode enabled if self.benchmark: from ingestion_monitor import IngestionMonitor self.monitor = IngestionMonitor() - + def get_supported_extensions(self) -> List[str]: """Get all supported file extensions.""" extensions = [] for extractor in self.extractors.values(): extensions.extend(extractor.get_supported_extensions()) return list(set(extensions)) - + @staticmethod - def process_single_file(args_tuple: Tuple) -> Optional[Tuple[str, str, List[str], str]]: + def process_single_file( + args_tuple: Tuple) -> Optional[Tuple[str, str, List[str], str]]: """ Process a single document file (worker function for multiprocessing). - + Args: - args_tuple: (doc_file_path, output_dir, url_mapping, preserve_tables, + args_tuple: (doc_file_path, output_dir, url_mapping, preserve_tables, preserve_lists, text_boundary, fixed_length, fixed_overlap, max_passage_length, passage_overlap) - + Returns: Tuple of (output_filename, text, passages, original_url) or None if processing failed """ - (doc_file_path, output_dir, url_mapping, preserve_tables, preserve_lists, + (doc_file_path, output_dir, url_mapping, preserve_tables, preserve_lists, text_boundary, fixed_length, fixed_overlap, max_passage_length, passage_overlap) = args_tuple - + doc_file = Path(doc_file_path) file_extension = doc_file.suffix.lower() - + # Create appropriate extractor if file_extension == '.pdf': extractor = PDFExtractor() elif file_extension in ['.html', '.htm'] and BeautifulSoup is not None: - extractor = HTMLExtractor(preserve_tables, preserve_lists, text_boundary) + extractor = HTMLExtractor( + preserve_tables, preserve_lists, text_boundary) else: return None - + # Extract text try: text = extractor.extract_text(str(doc_file)) @@ -401,33 +409,35 @@ def process_single_file(args_tuple: Tuple) -> Optional[Tuple[str, str, List[str] except Exception as e: print(f"Error extracting text from {doc_file}: {e}") return None - + # Split text into passages try: if fixed_length: - passages = split_into_fixed_passages(text, fixed_length, fixed_overlap or 32) + passages = split_into_fixed_passages( + text, fixed_length, fixed_overlap or 32) else: - passages = split_into_passages(text, max_passage_length, passage_overlap) + passages = split_into_passages( + text, max_passage_length, passage_overlap) except Exception as e: print(f"Error splitting text for {doc_file}: {e}") return None - + # Get original URL base_filename = get_base_filename(doc_file.name) original_url = url_mapping.get(base_filename, "") - + # Generate output filename output_filename = doc_file.stem + ".txt" - + return (output_filename, text, passages, original_url, doc_file.name) - + def process_documents(self, input_dir: str, output_dir: str, json_file: Optional[str] = None, - max_passage_length: int = 512, passage_overlap: int = 50, - fixed_length: Optional[int] = None, fixed_overlap: Optional[int] = None, - max_files: Optional[int] = None): + max_passage_length: int = 512, passage_overlap: int = 50, + fixed_length: Optional[int] = None, fixed_overlap: Optional[int] = None, + max_files: Optional[int] = None): """ Process documents in a directory, extracting text and splitting into passages. - + Args: input_dir: Directory containing document files output_dir: Directory to save extracted text files @@ -447,22 +457,23 @@ def process_documents(self, input_dir: str, output_dir: str, json_file: Optional # Find all supported document files supported_extensions = self.get_supported_extensions() document_files = [] - + for ext in supported_extensions: pattern = f"*{ext}" document_files.extend(input_path.glob(pattern)) - + # Sort for consistent processing order document_files = sorted(document_files) - + if not document_files: return - + if max_files: document_files = document_files[:max_files] - - print(f"Processing {len(document_files)} documents with {self.processes} parallel processes...") - + + print( + f"Processing {len(document_files)} documents with {self.processes} parallel processes...") + all_passages = [] passage_id = 0 @@ -472,22 +483,23 @@ def process_documents(self, input_dir: str, output_dir: str, json_file: Optional # Prepare arguments for multiprocessing process_args = [ - (str(doc_file), str(output_path), self.url_mapping, + (str(doc_file), str(output_path), self.url_mapping, self.preserve_tables, self.preserve_lists, self.text_boundary, fixed_length, fixed_overlap, max_passage_length, passage_overlap) for doc_file in document_files ] - + # Process documents in parallel with progress bar with Pool(processes=self.processes) as pool: with tqdm(total=len(document_files), desc="Processing documents") as pbar: - for result in pool.imap(self.process_single_file, process_args): + for result in pool.imap( + self.process_single_file, process_args): if result is None: pbar.update(1) continue - + output_filename, text, passages, original_url, doc_filename = result - + # Save text file output_file_path = output_path / output_filename try: @@ -497,19 +509,19 @@ def process_documents(self, input_dir: str, output_dir: str, json_file: Optional print(f"Error writing {output_file_path}: {e}") pbar.update(1) continue - + # Add passages to collection if JSON output requested if json_file: for passage in passages: passage_metadata = create_passage_metadata( doc_filename, passage_id, original_url=original_url) - + all_passages.append({ **passage_metadata, 'passage': passage, }) passage_id += 1 - + pbar.update(1) # Finalize monitoring and report @@ -538,65 +550,72 @@ def process_documents(self, input_dir: str, output_dir: str, json_file: Optional # Add component-level timing if monitoring was enabled if self.benchmark and self.monitor: for component_name, metrics in self.monitor.components.items(): - result[f'{component_name}_time_seconds'] = round(metrics.duration, 2) - result[f'{component_name}_throughput_mb_per_sec'] = round(metrics.throughput_mb_per_sec, 2) + result[f'{component_name}_time_seconds'] = round( + metrics.duration, 2) + result[f'{component_name}_throughput_mb_per_sec'] = round( + metrics.throughput_mb_per_sec, 2) return result - - def _process_document(self, doc_file: Path, file_extension: str) -> Optional[str]: + + def _process_document(self, doc_file: Path, + file_extension: str) -> Optional[str]: """Process a single document with optional monitoring.""" extractor = self.extractors[file_extension] - + if self.benchmark and self.monitor: - component_name = "html_parsing" if file_extension in ['.html', '.htm'] else "pdf_parsing" + component_name = "html_parsing" if file_extension in [ + '.html', '.htm'] else "pdf_parsing" file_size = doc_file.stat().st_size - with self.monitor.track_component(component_name, input_size_bytes=file_size, - items_count=1, is_pipeline_input=True): + with self.monitor.track_component(component_name, input_size_bytes=file_size, + items_count=1, is_pipeline_input=True): return extractor.extract_text(str(doc_file)) else: return extractor.extract_text(str(doc_file)) - + def _process_text_chunking(self, text: str, fixed_length: Optional[int], fixed_overlap: Optional[int], - max_passage_length: int, passage_overlap: int) -> List[str]: + max_passage_length: int, passage_overlap: int) -> List[str]: """Process text chunking with optional monitoring.""" def chunk_func(): if fixed_length: - return split_into_fixed_passages(text, fixed_length, fixed_overlap or 32) + return split_into_fixed_passages( + text, fixed_length, fixed_overlap or 32) else: - return split_into_passages(text, max_passage_length, passage_overlap) - + return split_into_passages( + text, max_passage_length, passage_overlap) + if self.benchmark and self.monitor: text_size = len(text.encode('utf-8')) - with self.monitor.track_component("text_chunking", input_size_bytes=text_size, - items_count=1, text_only=True) as ctx: + with self.monitor.track_component("text_chunking", input_size_bytes=text_size, + items_count=1, text_only=True) as ctx: passages = chunk_func() ctx.add_text_bytes(text_size) return passages else: return chunk_func() - - def _add_passages_to_collection(self, doc_file: Path, passages: List[str], - all_passages: List[Dict], passage_id: int) -> int: + + def _add_passages_to_collection(self, doc_file: Path, passages: List[str], + all_passages: List[Dict], passage_id: int) -> int: """Add passages to collection and return updated passage_id.""" base_filename = get_base_filename(doc_file.name) original_url = self.url_mapping.get(base_filename, "") - + for passage in passages: passage_metadata = create_passage_metadata( doc_file.name, passage_id, original_url=original_url) - + all_passages.append({ **passage_metadata, 'passage': passage, }) passage_id += 1 - + return passage_id - + def _report_processing_performance(self): """Report processing performance if monitoring is enabled.""" # Suppress the detailed benchmark summary - metrics are tracked internally - # and will be reported by the calling script (measure_indexing_with_chunking.py) + # and will be reported by the calling script + # (measure_indexing_with_chunking.py) pass @@ -605,31 +624,33 @@ def main(): description="Extract text from PDF and HTML files and split into passages", formatter_class=argparse.RawTextHelpFormatter ) - - parser.add_argument("input_dir", help="Directory containing document files (PDF/HTML)") + + parser.add_argument( + "input_dir", + help="Directory containing document files (PDF/HTML)") parser.add_argument("output_dir", help="Directory to save text files") parser.add_argument("--json", help="JSON file to save passage data") - parser.add_argument("--max-length", type=int, default=512, - help="Maximum passage length in characters (default: 512)") + parser.add_argument("--max-length", type=int, default=512, + help="Maximum passage length in characters (default: 512)") parser.add_argument("--overlap", type=int, default=50, - help="Overlap between passages in characters (default: 50)") + help="Overlap between passages in characters (default: 50)") parser.add_argument("--fixed-length", type=int, - help="Use fixed-length passages instead of variable-length") + help="Use fixed-length passages instead of variable-length") parser.add_argument("--fixed-overlap", type=int, default=32, - help="Overlap for fixed-length passages (default: 32)") + help="Overlap for fixed-length passages (default: 32)") parser.add_argument("--max-files", type=int, - help="Maximum number of files to process (for testing)") + help="Maximum number of files to process (for testing)") parser.add_argument("--no-tables", action="store_true", - help="Don't preserve table structure (HTML only)") - parser.add_argument("--no-lists", action="store_true", - help="Don't preserve list structure (HTML only)") - parser.add_argument("--text-boundary", choices=["sentence", "word", "none"], - default="sentence", - help="Text boundary optimization: 'sentence' (default), 'word', or 'none'") + help="Don't preserve table structure (HTML only)") + parser.add_argument("--no-lists", action="store_true", + help="Don't preserve list structure (HTML only)") + parser.add_argument("--text-boundary", choices=["sentence", "word", "none"], + default="sentence", + help="Text boundary optimization: 'sentence' (default), 'word', or 'none'") parser.add_argument("--processes", type=int, default=4, - help="Number of parallel processes for document processing (default: 4)") + help="Number of parallel processes for document processing (default: 4)") parser.add_argument("--benchmark", action="store_true", - help="Enable performance monitoring and detailed component analysis") + help="Enable performance monitoring and detailed component analysis") args = parser.parse_args() @@ -640,7 +661,7 @@ def main(): benchmark=args.benchmark, processes=args.processes ) - + processor.process_documents( input_dir=args.input_dir, output_dir=args.output_dir, diff --git a/e2e-rag/reference_SUT.py b/e2e-rag/reference_SUT.py index d1e3c5c723..9bccb2e6bd 100644 --- a/e2e-rag/reference_SUT.py +++ b/e2e-rag/reference_SUT.py @@ -112,8 +112,10 @@ def __init__( # Performance test mode if hasattr(args, 'perf_test_mode') and args.perf_test_mode: from perf_test_cache import PerfTestCache - log.info(f"Loading performance test cache from {args.perf_test_mode}") - self.llm_config['perf_test_cache'] = PerfTestCache(args.perf_test_mode) + log.info( + f"Loading performance test cache from {args.perf_test_mode}") + self.llm_config['perf_test_cache'] = PerfTestCache( + args.perf_test_mode) else: self.llm_config['perf_test_cache'] = None @@ -130,11 +132,14 @@ def __init__( # Initialize database log.info("Initializing RAG database...") self.rag_db = VectorDB( - retriever_model=args.retriever_model if hasattr(args, 'retriever_model') else 'BAAI/bge-base-en-v1.5', - reranker_model=args.reranker_model if hasattr(args, 'reranker_model') else 'BAAI/bge-reranker-base', + retriever_model=args.retriever_model if hasattr( + args, 'retriever_model') else 'BAAI/bge-base-en-v1.5', + reranker_model=args.reranker_model if hasattr( + args, 'reranker_model') else 'BAAI/bge-reranker-base', device=device, database=db_path.replace('.db', ''), - num_embedding_devices=args.num_embedding_devices if hasattr(args, 'num_embedding_devices') else 1, + num_embedding_devices=args.num_embedding_devices if hasattr( + args, 'num_embedding_devices') else 1, benchmark=args.benchmark if hasattr(args, 'benchmark') else False ) diff --git a/e2e-rag/reference_SUT_datasetup.py b/e2e-rag/reference_SUT_datasetup.py index d23d1d7802..62e81f0f8d 100644 --- a/e2e-rag/reference_SUT_datasetup.py +++ b/e2e-rag/reference_SUT_datasetup.py @@ -117,7 +117,8 @@ def __init__( # Initialize HTML extractor if not HAVE_HTML: - raise RuntimeError("BeautifulSoup required for HTML processing. Install with: pip install beautifulsoup4") + raise RuntimeError( + "BeautifulSoup required for HTML processing. Install with: pip install beautifulsoup4") log.info("Initializing HTML extractor...") self.html_extractor = HTMLExtractor( @@ -206,7 +207,8 @@ def _process_document(self, query_sample): file_path = document_info['file_path'] file_name = document_info['file_name'] - log.info(f"Processing document {sample_id} (QID: {query_id}): {file_name}") + log.info( + f"Processing document {sample_id} (QID: {query_id}): {file_name}") start_time = time.time() success = 0 # 0 = failure, 1 = success @@ -218,7 +220,8 @@ def _process_document(self, query_sample): if not text or len(text.strip()) == 0: # Empty file - create minimal passage with file name - log.warning(f"No text extracted from {file_name}, creating minimal passage") + log.warning( + f"No text extracted from {file_name}, creating minimal passage") text = f"Document: {file_name}" # Continue with text (even if minimal) @@ -232,22 +235,28 @@ def _process_document(self, query_sample): if not passages: # Even after splitting, no passages - this shouldn't happen now # but create one minimal passage as fallback - log.warning(f"No passages created from {file_name}, using text as-is") + log.warning( + f"No passages created from {file_name}, using text as-is") passages = [text] # Step 3: Generate embeddings (parallel, outside lock) - passage_metadata = [{'source': file_name, 'passage_id': i} for i in range(len(passages))] + passage_metadata = [{'source': file_name, 'passage_id': i} + for i in range(len(passages))] - # Generate embeddings WITHOUT holding db_lock (allows parallel embedding generation) - log.info(f"Generating embeddings for {len(passages)} passages from {file_name}") + # Generate embeddings WITHOUT holding db_lock (allows parallel + # embedding generation) + log.info( + f"Generating embeddings for {len(passages)} passages from {file_name}") if self.rag_db._num_embedding_devices > 1: embeddings = self.rag_db._embed_documents_parallel(passages) else: - embeddings = self.rag_db._embedding_model.embed_documents(passages) + embeddings = self.rag_db._embedding_model.embed_documents( + passages) log.info(f"Embeddings generated for {file_name}, adding to index") - # Step 4: Add to index (thread-safe, holds lock only for index update) + # Step 4: Add to index (thread-safe, holds lock only for index + # update) with self.db_lock: # Add embeddings and documents to vector store ids = self.rag_db._vector_store.add_embeddings( @@ -260,7 +269,8 @@ def _process_document(self, query_sample): self.total_passages_indexed += len(passages) success = 1 - log.info(f"Successfully indexed {len(passages)} passages from {file_name}") + log.info( + f"Successfully indexed {len(passages)} passages from {file_name}") # Check if this is the last file to complete with self.completion_lock: @@ -268,7 +278,8 @@ def _process_document(self, query_sample): is_last_file = (self.completed_count == len(self.qsl)) if is_last_file: - log.info(f"Last file completed! Saving database and computing MD5...") + log.info( + f"Last file completed! Saving database and computing MD5...") # Save database db_path = f"{self.database}.db" save_start = time.time() @@ -287,7 +298,8 @@ def _process_document(self, query_sample): self.db_md5 = md5_hash.hexdigest() md5_end = time.time() md5_duration = md5_end - md5_start - log.info(f"Database MD5: {self.db_md5} (computed in {md5_duration:.2f}s)") + log.info( + f"Database MD5: {self.db_md5} (computed in {md5_duration:.2f}s)") self.db_saved = True @@ -317,7 +329,8 @@ def _process_document(self, query_sample): } # Create response for loadgen - # For the last file, include MD5 hash; otherwise just success/failure byte + # For the last file, include MD5 hash; otherwise just success/failure + # byte if success and self.db_saved: # Last file: return MD5 hash as response response_bytes = self.db_md5.encode('utf-8') @@ -339,7 +352,8 @@ def _process_document(self, query_sample): ) lg.QuerySamplesComplete([response]) - log.info(f"Completed document {sample_id} (QID: {query_id}): {file_name}") + log.info( + f"Completed document {sample_id} (QID: {query_id}): {file_name}") def flush_queries(self): """ @@ -349,10 +363,12 @@ def flush_queries(self): log.info("Flushing queries...") if self.db_saved: - log.info(f"Database already saved by last file. MD5: {self.db_md5}") + log.info( + f"Database already saved by last file. MD5: {self.db_md5}") else: # Fallback: save database if somehow not done yet - log.warning("Database not saved by last file - saving now as fallback") + log.warning( + "Database not saved by last file - saving now as fallback") db_path = f"{self.database}.db" save_start = time.time() self.rag_db.serialize(db_path) @@ -379,13 +395,15 @@ def finalize(self): failed_count = len(self.failed_documents) success_count = total_docs - failed_count - avg_time_per_doc = sum(self.processing_times) / total_docs if total_docs > 0 else 0 - throughput_passages = self.total_passages_indexed / total_time if total_time > 0 else 0 + avg_time_per_doc = sum( + self.processing_times) / total_docs if total_docs > 0 else 0 + throughput_passages = self.total_passages_indexed / \ + total_time if total_time > 0 else 0 throughput_docs = total_docs / total_time if total_time > 0 else 0 - log.info("="*80) + log.info("=" * 80) log.info("Datasetup Complete") - log.info("="*80) + log.info("=" * 80) log.info(f"Total documents processed: {total_docs}") log.info(f"Successful: {success_count}") log.info(f"Failed: {failed_count}") @@ -394,10 +412,11 @@ def finalize(self): log.info(f"Throughput: {throughput_passages:.2f} passages/sec") log.info(f"Throughput: {throughput_docs:.2f} docs/sec") log.info(f"Average time per document: {avg_time_per_doc:.2f}s") - log.info("="*80) + log.info("=" * 80) # Cleanup reranker queue if it exists - if hasattr(self.rag_db, '_reranker_queue') and self.rag_db._reranker_queue is not None: + if hasattr( + self.rag_db, '_reranker_queue') and self.rag_db._reranker_queue is not None: log.info("Shutting down reranker queue...") self.rag_db._reranker_queue.stop() @@ -419,12 +438,14 @@ def save_results(self, output_path): failed_count = len(self.failed_documents) success_count = total_docs - failed_count - throughput_passages = self.total_passages_indexed / total_time if total_time > 0 else 0 + throughput_passages = self.total_passages_indexed / \ + total_time if total_time > 0 else 0 throughput_docs = total_docs / total_time if total_time > 0 else 0 # Get vector count from database vector_count = 0 - if hasattr(self.rag_db, '_vector_store') and hasattr(self.rag_db._vector_store, 'index'): + if hasattr(self.rag_db, '_vector_store') and hasattr( + self.rag_db._vector_store, 'index'): vector_count = self.rag_db._vector_store.index.ntotal output = { diff --git a/e2e-rag/reference_mlperf.py b/e2e-rag/reference_mlperf.py index 83ebc0fcd0..2b306c5742 100644 --- a/e2e-rag/reference_mlperf.py +++ b/e2e-rag/reference_mlperf.py @@ -90,7 +90,8 @@ def get_args(): help="Number of queries for performance testing (None = all)" ) - # Multi-shot specific parameters (these are unique to multi_shot_retrieval.py) + # Multi-shot specific parameters (these are unique to + # multi_shot_retrieval.py) parser.add_argument( '--max-sub-queries', type=int, @@ -171,9 +172,9 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) # Initialize SUT - print("\n" + "="*80) + print("\n" + "=" * 80) print("Initializing RAG-QnA SUT...") - print("="*80) + print("=" * 80) sut = E2ESUT( dataset_path=args.dataset_path, @@ -194,9 +195,9 @@ def main(): args=args, # Pass full args for additional params ) - print("\n" + "="*80) + print("\n" + "=" * 80) print("SUT initialization complete") - print("="*80 + "\n") + print("=" * 80 + "\n") # Configure loadgen settings settings = lg.TestSettings() @@ -227,9 +228,9 @@ def main(): log_settings.log_output = log_output_settings # Run loadgen test - print("\n" + "="*80) + print("\n" + "=" * 80) print("Running MLPerf Loadgen test...") - print("="*80 + "\n") + print("=" * 80 + "\n") lg.StartTestWithLogSettings( sut.sut, @@ -239,9 +240,9 @@ def main(): args.audit_conf ) - print("\n" + "="*80) + print("\n" + "=" * 80) print("Loadgen test complete") - print("="*80 + "\n") + print("=" * 80 + "\n") # Finalize SUT (save logs, cleanup) sut.finalize() @@ -253,9 +254,9 @@ def main(): # Run accuracy evaluation if in accuracy mode if args.accuracy: - print("\n" + "="*80) + print("\n" + "=" * 80) print("Running accuracy evaluation...") - print("="*80 + "\n") + print("=" * 80 + "\n") cmd = [ "python3", @@ -269,9 +270,9 @@ def main(): print(f"Command: {' '.join(cmd)}") subprocess.check_call(cmd) - print("\n" + "="*80) + print("\n" + "=" * 80) print("Done!") - print("="*80) + print("=" * 80) if __name__ == "__main__": diff --git a/e2e-rag/reference_mlperf_datasetup.py b/e2e-rag/reference_mlperf_datasetup.py index 5ec27d1232..647cd8c97b 100644 --- a/e2e-rag/reference_mlperf_datasetup.py +++ b/e2e-rag/reference_mlperf_datasetup.py @@ -171,9 +171,9 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) # Initialize SUT - print("\n" + "="*80) + print("\n" + "=" * 80) print("Initializing RAG-DB SUT...") - print("="*80) + print("=" * 80) sut = DatasetupSUT( documents_dir=args.documents_dir, @@ -192,9 +192,9 @@ def main(): args=args, ) - print("\n" + "="*80) + print("\n" + "=" * 80) print("SUT initialization complete") - print("="*80 + "\n") + print("=" * 80 + "\n") # Configure loadgen settings settings = lg.TestSettings() @@ -225,9 +225,9 @@ def main(): log_settings.log_output = log_output_settings # Run loadgen test - print("\n" + "="*80) + print("\n" + "=" * 80) print("Running MLPerf Loadgen test...") - print("="*80 + "\n") + print("=" * 80 + "\n") lg.StartTestWithLogSettings( sut.sut, @@ -237,9 +237,9 @@ def main(): args.audit_conf ) - print("\n" + "="*80) + print("\n" + "=" * 80) print("Loadgen test complete") - print("="*80 + "\n") + print("=" * 80 + "\n") # Finalize SUT (batch index, save database, cleanup) sut.finalize() @@ -249,9 +249,9 @@ def main(): sut.save_results(results_path) print(f"Results saved to {results_path}") - print("\n" + "="*80) + print("\n" + "=" * 80) print("Done!") - print("="*80) + print("=" * 80) if __name__ == "__main__": diff --git a/e2e-rag/reranker_worker.py b/e2e-rag/reranker_worker.py index e3acfea045..fc0b568a82 100644 --- a/e2e-rag/reranker_worker.py +++ b/e2e-rag/reranker_worker.py @@ -100,7 +100,8 @@ def _reranker_worker_main( response_q.put((request_id, None, repr(e))) -def _do_rerank(model, tokenizer, device: str, query: str, passages: List[str]) -> List[Tuple[str, float]]: +def _do_rerank(model, tokenizer, device: str, query: str, + passages: List[str]) -> List[Tuple[str, float]]: """ColBERT late-interaction reranking with MaxSim scoring.""" import torch @@ -180,10 +181,12 @@ def start(self): self._process.start() if not self._ready_event.wait(timeout=300): - raise RuntimeError("reranker child failed to become ready within 300s") + raise RuntimeError( + "reranker child failed to become ready within 300s") self._dispatcher_running = True - self._dispatcher_thread = threading.Thread(target=self._dispatcher_loop, daemon=True) + self._dispatcher_thread = threading.Thread( + target=self._dispatcher_loop, daemon=True) self._dispatcher_thread.start() def stop(self): @@ -215,12 +218,14 @@ def _dispatcher_loop(self): continue event, container = slot if err is not None: - container["error"] = RuntimeError(f"reranker child error: {err}") + container["error"] = RuntimeError( + f"reranker child error: {err}") else: container["result"] = result event.set() - def submit(self, query: str, passages: List[str]) -> List[Tuple[str, float]]: + def submit(self, query: str, + passages: List[str]) -> List[Tuple[str, float]]: if self._process is None or not self._process.is_alive(): raise RuntimeError("reranker process is not running") diff --git a/e2e-rag/retrieve/__init__.py b/e2e-rag/retrieve/__init__.py index b14c776219..2e5187bee9 100644 --- a/e2e-rag/retrieve/__init__.py +++ b/e2e-rag/retrieve/__init__.py @@ -22,4 +22,4 @@ from .vectordb import VectorDB from .filter import filter, get_score_statistics -__all__ = ['RagDB', 'VectorDB', 'filter', 'get_score_statistics'] \ No newline at end of file +__all__ = ['RagDB', 'VectorDB', 'filter', 'get_score_statistics'] diff --git a/e2e-rag/retrieve/filter.py b/e2e-rag/retrieve/filter.py index 710977dac5..aa7e07b4d9 100644 --- a/e2e-rag/retrieve/filter.py +++ b/e2e-rag/retrieve/filter.py @@ -24,7 +24,7 @@ Implements various thresholding approaches: - Top-p (nucleus sampling) - popular in NLP - Score threshold - absolute quality bar -- Relative threshold - adaptive to query difficulty +- Relative threshold - adaptive to query difficulty - Elbow method - natural breakpoints - Percentile-based - statistical cutoffs @@ -40,14 +40,14 @@ def softmax(scores: List[float], temperature: float = 1.0) -> List[float]: """Convert scores to probabilities using softmax with temperature scaling. - + Args: scores: List of scores to convert temperature: Temperature parameter (lower = sharper distribution) - temperature = 1.0: standard softmax - temperature < 1.0: sharper (more weight on top scores) - temperature > 1.0: smoother (more uniform) - + Returns: List of probabilities that sum to 1.0 """ @@ -58,95 +58,98 @@ def softmax(scores: List[float], temperature: float = 1.0) -> List[float]: return [exp_s / sum_exp for exp_s in exp_scores] -def top_p_filter(results_with_scores: List[Tuple[Any, float]], p: float = 0.9) -> List[Any]: +def top_p_filter( + results_with_scores: List[Tuple[Any, float]], p: float = 0.9) -> List[Any]: """ Top-p (nucleus) sampling: Take results until cumulative probability >= p - + Args: results_with_scores: List of (result, score) tuples, sorted by score DESC p: Cumulative probability threshold (0.8-0.95 typical) - + Returns: Filtered results list """ if not results_with_scores: return [] - + scores = [score for _, score in results_with_scores] - - #print(f"\n[DEBUG top_p_filter] p={p}, num_candidates={len(scores)}") - #print(f"[DEBUG] Score range: [{min(scores):.4f}, {max(scores):.4f}]") - #print(f"[DEBUG] Score mean: {sum(scores)/len(scores):.4f}") - #print(f"[DEBUG] First 10 scores: {[f'{s:.4f}' for s in scores[:10]]}") - + + # print(f"\n[DEBUG top_p_filter] p={p}, num_candidates={len(scores)}") + # print(f"[DEBUG] Score range: [{min(scores):.4f}, {max(scores):.4f}]") + # print(f"[DEBUG] Score mean: {sum(scores)/len(scores):.4f}") + # print(f"[DEBUG] First 10 scores: {[f'{s:.4f}' for s in scores[:10]]}") + # Use temperature scaling to sharpen the distribution # Lower temperature = more discriminative (top docs get higher probability) # For vector embeddings with compressed L2 distances, use very low temperature # Temperature = 0.01 to 0.05 for L2 distances in range [0.27-0.42] temperature = 1 - #temperature = 0.02 + # temperature = 0.02 probs = softmax(scores, temperature=temperature) - - #print(f"[DEBUG] Temperature: {temperature}") - #print(f"[DEBUG] Probability range: [{min(probs):.6f}, {max(probs):.6f}]") - #print(f"[DEBUG] First 10 probs: {[f'{p:.6f}' for p in probs[:10]]}") - #print(f"[DEBUG] Prob sum: {sum(probs):.6f}") - + + # print(f"[DEBUG] Temperature: {temperature}") + # print(f"[DEBUG] Probability range: [{min(probs):.6f}, {max(probs):.6f}]") + # print(f"[DEBUG] First 10 probs: {[f'{p:.6f}' for p in probs[:10]]}") + # print(f"[DEBUG] Prob sum: {sum(probs):.6f}") + cumulative_prob = 0.0 selected_results = [] - - for i, ((result, score), prob) in enumerate(zip(results_with_scores, probs)): + + for i, ((result, score), prob) in enumerate( + zip(results_with_scores, probs)): cumulative_prob += prob selected_results.append(result) - + if cumulative_prob >= p: - print(f"[DEBUG] Selected {i+1} documents (cumulative_prob={cumulative_prob:.4f} >= p={p})") + print( + f"[DEBUG] Selected {i+1} documents (cumulative_prob={cumulative_prob:.4f} >= p={p})") break - + return selected_results -def score_threshold_filter(results_with_scores: List[Tuple[Any, float]], - threshold: float, higher_better: bool = True) -> List[Any]: +def score_threshold_filter(results_with_scores: List[Tuple[Any, float]], + threshold: float, higher_better: bool = True) -> List[Any]: """ Absolute score threshold filtering. - + Args: results_with_scores: List of (result, score) tuples threshold: Absolute score cutoff higher_better: If True, keep scores >= threshold, else <= threshold - + Returns: Filtered results list """ selected_results = [] - + for result, score in results_with_scores: if higher_better and score >= threshold: selected_results.append(result) elif not higher_better and score <= threshold: selected_results.append(result) - + return selected_results -def relative_threshold_filter(results_with_scores: List[Tuple[Any, float]], - ratio: float = 0.8) -> List[Any]: +def relative_threshold_filter(results_with_scores: List[Tuple[Any, float]], + ratio: float = 0.8) -> List[Any]: """ Relative threshold: Keep top ratio fraction of results based on score range. - + For both positive and negative scores: - Calculates score range between best and worst - Keeps only results within top ratio% of that range - + Args: results_with_scores: List of (result, score) tuples (sorted desc, best first) ratio: Fraction of score range to keep (0.7-0.9 typical) e.g., 0.9 means keep top 90% of score range - + Returns: Filtered results list - + Example with negative scores: Scores: [-0.42, -0.43, -0.44, ..., -0.49] best=-0.42, worst=-0.49, range=0.07 @@ -156,85 +159,88 @@ def relative_threshold_filter(results_with_scores: List[Tuple[Any, float]], """ if not results_with_scores: return [] - + # Get best and worst scores best_score = results_with_scores[0][1] worst_score = results_with_scores[-1][1] - + # Calculate the score range score_range = best_score - worst_score # Always positive since sorted desc - + # Calculate threshold: start from best, move down by (1-ratio) of range cutoff_distance = score_range * (1 - ratio) threshold = best_score - cutoff_distance - - return score_threshold_filter(results_with_scores, threshold, higher_better=True) + return score_threshold_filter( + results_with_scores, threshold, higher_better=True) -def elbow_method_filter(results_with_scores: List[Tuple[Any, float]]) -> List[Any]: + +def elbow_method_filter( + results_with_scores: List[Tuple[Any, float]]) -> List[Any]: """ Elbow method: Find largest score gap and cut there. - + Args: results_with_scores: List of (result, score) tuples, sorted by score DESC - + Returns: Filtered results list """ if len(results_with_scores) <= 1: return [result for result, _ in results_with_scores] - + scores = [score for _, score in results_with_scores] - + # Calculate gaps between consecutive scores gaps = [] for i in range(len(scores) - 1): gap = scores[i] - scores[i + 1] # Assuming DESC order gaps.append(gap) - + # Find largest gap if not gaps: return [result for result, _ in results_with_scores] - + max_gap_idx = gaps.index(max(gaps)) cutoff_point = max_gap_idx + 1 # Include the score before the gap - + return [result for result, _ in results_with_scores[:cutoff_point]] -def percentile_filter(results_with_scores: List[Tuple[Any, float]], - percentile: float = 90.0) -> List[Any]: +def percentile_filter(results_with_scores: List[Tuple[Any, float]], + percentile: float = 90.0) -> List[Any]: """ Percentile-based filtering: Keep top X percentile of scores. - + Args: results_with_scores: List of (result, score) tuples percentile: Percentile threshold (80-95 typical) - + Returns: Filtered results list """ if not results_with_scores: return [] - + scores = [score for _, score in results_with_scores] threshold = np.percentile(scores, percentile) - - return score_threshold_filter(results_with_scores, threshold, higher_better=True) + + return score_threshold_filter( + results_with_scores, threshold, higher_better=True) -def filter(rag_db, query: str, method: str = "top_p", +def filter(rag_db, query: str, method: str = "top_p", max_results: int = 100, **kwargs) -> List[Any]: """ Perform adaptive retrieval using score-based filtering. - + Args: rag_db: RAG database instance (VectorDB) query: Search query method: Filtering method ("top_p", "score_threshold", "relative", "elbow", "percentile") max_results: Maximum results to retrieve initially **kwargs: Method-specific parameters - + Returns: Filtered results list """ @@ -242,32 +248,33 @@ def filter(rag_db, query: str, method: str = "top_p", if hasattr(rag_db, 'lookup_with_scores'): results_with_scores = rag_db.lookup_with_scores(query, k=max_results) else: - raise ValueError(f"Database {type(rag_db)} doesn't support score-based retrieval") - + raise ValueError( + f"Database {type(rag_db)} doesn't support score-based retrieval") + # Results already have proper similarity scores (higher is better) from lookup_with_scores # Sort by score (descending - higher is better) results_with_scores.sort(key=lambda x: x[1], reverse=True) - + # Apply filtering method if method == "top_p": p = kwargs.get("p", 0.9) return top_p_filter(results_with_scores, p) - + elif method == "score_threshold": threshold = kwargs.get("threshold", 5.0) return score_threshold_filter(results_with_scores, threshold) - + elif method == "relative": ratio = kwargs.get("ratio", 0.8) return relative_threshold_filter(results_with_scores, ratio) - + elif method == "elbow": return elbow_method_filter(results_with_scores) - + elif method == "percentile": percentile = kwargs.get("percentile", 90.0) return percentile_filter(results_with_scores, percentile) - + else: raise ValueError(f"Unknown filtering method: {method}") @@ -276,10 +283,10 @@ def get_score_statistics(rag_db, query: str, k: int = 100) -> Dict[str, float]: """Get score distribution statistics for threshold calibration.""" results_with_scores = rag_db.lookup_with_scores(query, k=k) scores = [score for _, score in results_with_scores] - + if not scores: return {} - + return { "min": min(scores), "max": max(scores), @@ -294,4 +301,4 @@ def get_score_statistics(rag_db, query: str, k: int = 100) -> Dict[str, float]: # Backward compatibility alias -adaptive_retrieval = filter \ No newline at end of file +adaptive_retrieval = filter diff --git a/e2e-rag/retrieve/ragdb.py b/e2e-rag/retrieve/ragdb.py index 19f65c4416..e4c7e1ba79 100644 --- a/e2e-rag/retrieve/ragdb.py +++ b/e2e-rag/retrieve/ragdb.py @@ -18,15 +18,17 @@ import os from typing import List, Dict, Any + class RagDB(abc.ABC): """Base class for retrieval-augmented generation databases.""" - + def __init__(self, reranker_model: str = None, device: str = "auto", benchmark: bool = False, reranker_device: str = None): self._reranker_model_name = reranker_model self._device = self._determine_device(device) # Reranker device defaults to inheriting from --device. - self._reranker_device = self._determine_device(reranker_device) if reranker_device else self._device + self._reranker_device = self._determine_device( + reranker_device) if reranker_device else self._device self._reranker_queue = None self._benchmark = benchmark self._monitor = None @@ -39,12 +41,12 @@ def __init__(self, reranker_model: str = None, device: str = "auto", # Initialize out-of-process reranker if specified if self._reranker_model_name: self._init_reranker() - + def _determine_device(self, device: str) -> str: """Determine the best device to use. Delegates to utils.detect_device() for auto detection so device-selection - logic lives in one place. ROCm maps to "cuda" + logic lives in one place. ROCm maps to "cuda" """ if device == "rocm": return "cuda" @@ -52,14 +54,14 @@ def _determine_device(self, device: str) -> str: from utils import detect_device return detect_device() return device - + @staticmethod def get_data_dir(db_name: str) -> str: """Get data directory based on database name.""" from pathlib import Path base_name = Path(db_name).stem # Remove .db extension if present return f"{base_name}_data" - + @staticmethod def get_db_path(db_name: str) -> str: """Get database file path based on database name.""" @@ -90,11 +92,11 @@ def _init_reranker(self): omp_threads=int(omp_threads) if omp_threads else None, ) self._reranker_queue.start() - - def _track_component(self, name: str, total_chars: int, item_count: int, func, - is_pipeline_input: bool = False, is_pipeline_output: bool = False): + + def _track_component(self, name: str, total_chars: int, item_count: int, func, + is_pipeline_input: bool = False, is_pipeline_output: bool = False): """Execute function with optional component tracking. - + Args: name: Component name total_chars: Input size in bytes @@ -104,26 +106,27 @@ def _track_component(self, name: str, total_chars: int, item_count: int, func, is_pipeline_output: Mark as pipeline output for aggregation """ if self._benchmark and self._monitor: - with self._monitor.track_component(name, input_size_bytes=total_chars, - items_count=item_count, text_only=True, - is_pipeline_input=is_pipeline_input, - is_pipeline_output=is_pipeline_output) as ctx: + with self._monitor.track_component(name, input_size_bytes=total_chars, + items_count=item_count, text_only=True, + is_pipeline_input=is_pipeline_input, + is_pipeline_output=is_pipeline_output) as ctx: result = func() ctx.add_text_bytes(total_chars) return result else: return func() - + def _start_ingestion_timer(self): """Start the ingestion timer. Works for both benchmark and non-benchmark modes.""" import time if self._benchmark and self._monitor: self._monitor.start_ingestion() return time.perf_counter() - - def _report_performance(self, ingestion_start_time: float, item_count: int, total_chars: int, db_type: str): + + def _report_performance(self, ingestion_start_time: float, + item_count: int, total_chars: int, db_type: str): """Report performance metrics with optional detailed breakdown. - + Args: ingestion_start_time: Start time from _start_ingestion_timer() (used only in non-benchmark mode) item_count: Number of items processed @@ -131,7 +134,7 @@ def _report_performance(self, ingestion_start_time: float, item_count: int, tota db_type: Database type string for display """ import time - + if self._benchmark and self._monitor: with self._monitor.track_ingestion() as ingestion_ctx: ingestion_ctx.set_item_count(item_count) @@ -142,9 +145,11 @@ def _report_performance(self, ingestion_start_time: float, item_count: int, tota duration = end_time - ingestion_start_time docs_per_sec = item_count / duration if duration > 0 else 0 chars_per_sec = total_chars / duration if duration > 0 else 0 - print(f"{db_type} ingestion: {item_count} docs, {total_chars:,} chars in {duration:.2f}s") - print(f" Performance: {docs_per_sec:.1f} docs/sec, {chars_per_sec/1024:.1f} KB/sec") - + print( + f"{db_type} ingestion: {item_count} docs, {total_chars:,} chars in {duration:.2f}s") + print( + f" Performance: {docs_per_sec:.1f} docs/sec, {chars_per_sec/1024:.1f} KB/sec") + def enable_threading(self): """Enable thread-safe access. Override in subclasses that need locks.""" pass @@ -153,26 +158,27 @@ def enable_threading(self): def ingest(self, passages: List[str], metadatas: List[Dict[str, Any]]): """Ingest passages and their metadata into the database.""" pass - + @abc.abstractmethod def lookup(self, query: str, k: int) -> List[Any]: """Retrieve top-k relevant passages for a query.""" pass - + @abc.abstractmethod def serialize(self, path: str): """Serialize the database to disk.""" pass - + @abc.abstractmethod def from_serialized(self, path: str): """Load the database from disk.""" pass - + def ingest_from_folder(self, folder_path: str, **kwargs): """Ingest data from a folder. Default implementation raises NotImplementedError.""" - raise NotImplementedError(f"Folder ingestion not supported for {self.__class__.__name__}") - + raise NotImplementedError( + f"Folder ingestion not supported for {self.__class__.__name__}") + def ingest_from_file(self, file_path: str, **kwargs): """Ingest data from a JSON file. Default implementation for JSON files. @@ -204,35 +210,39 @@ def ingest_from_file(self, file_path: str, **kwargs): # Use child for embedding doc_list.append(entry['child_passage']) # Store all metadata including parent - metadata = {k: v for k, v in entry.items() if k != 'child_passage'} + metadata = {k: v for k, v in entry.items() if k != + 'child_passage'} passage_metadata.append(metadata) else: # Flat format for entry in passage_data: doc_list.append(entry['passage']) - passage_metadata.append({k: v for k, v in entry.items() if k != 'passage'}) + passage_metadata.append( + {k: v for k, v in entry.items() if k != 'passage'}) print(f"Ingesting {len(doc_list)} passages from JSON file {file_path}") - return self.ingest(doc_list, passage_metadata, passages_path=file_path, **kwargs) + return self.ingest(doc_list, passage_metadata, + passages_path=file_path, **kwargs) def ingest_from_path(self, source_path: str, **kwargs): """Handle both file and folder ingestion. - + Default implementation that delegates to appropriate methods: - Folders: calls ingest_from_folder() (may raise NotImplementedError if not overridden) - Files: calls ingest_from_file() (default JSON implementation) """ from pathlib import Path - + source_path = Path(source_path) - + if source_path.is_dir(): print(f"Ingesting documents from folder {source_path}") return self.ingest_from_folder(source_path, **kwargs) elif source_path.is_file(): return self.ingest_from_file(source_path, **kwargs) else: - raise ValueError(f"Source path {source_path} is neither a file nor a directory") + raise ValueError( + f"Source path {source_path} is neither a file nor a directory") def shutdown_reranker(self): """Tear down the reranker child process. Safe to call multiple times.""" @@ -246,7 +256,8 @@ def rerank(self, query: str, passages: List[str]): return self._reranker_queue.submit(query, passages) return [(p, 0.0) for p in passages] - def lookup_with_rerank(self, query: str, k: int, rerank_k: int = None) -> List[Any]: + def lookup_with_rerank(self, query: str, k: int, + rerank_k: int = None) -> List[Any]: """Retrieve and rerank passages.""" if rerank_k is None: rerank_k = k @@ -257,13 +268,13 @@ def lookup_with_rerank(self, query: str, k: int, rerank_k: int = None) -> List[A # If no reranker or fewer results than requested, return as-is if self._reranker_queue is None or len(results) <= k: return results[:k] - + # Extract passages for reranking passages = [result.page_content for result in results] - + # Rerank reranked_passages = self.rerank(query, passages) - + # Map back to original results and return top-k reranked_results = [] for passage, score in reranked_passages[:k]: @@ -271,7 +282,7 @@ def lookup_with_rerank(self, query: str, k: int, rerank_k: int = None) -> List[A if result.page_content == passage: reranked_results.append(result) break - + return reranked_results @property diff --git a/e2e-rag/retrieve/vectordb.py b/e2e-rag/retrieve/vectordb.py index 73499e93d5..d0c392568a 100644 --- a/e2e-rag/retrieve/vectordb.py +++ b/e2e-rag/retrieve/vectordb.py @@ -44,7 +44,11 @@ def _alias_missing_faiss_swigfaiss_modules() -> None: serialized databases loadable across heterogeneous installs without rebuilding them. """ - candidates = ("swigfaiss_avx512_spr", "swigfaiss_avx512", "swigfaiss_avx2", "swigfaiss") + candidates = ( + "swigfaiss_avx512_spr", + "swigfaiss_avx512", + "swigfaiss_avx2", + "swigfaiss") available = None for name in candidates: full = f"faiss.{name}" @@ -66,7 +70,10 @@ def _alias_missing_faiss_swigfaiss_modules() -> None: except ImportError: sys.modules[full] = available -# Worker function for parallel embedding generation (must be at module level for multiprocessing) +# Worker function for parallel embedding generation (must be at module +# level for multiprocessing) + + def _parallel_embed_worker(device_id, input_chunk_indices, input_chunks, result_queue, model_name, encode_kwargs, base_device, numa_plan=None): """Worker function to generate embeddings on a specific device. @@ -120,9 +127,11 @@ def _parallel_embed_worker(device_id, input_chunk_indices, input_chunks, result_ encode_kwargs=encode_kwargs ) - print(f"✓ Device {device}: Loaded model, processing {len(input_chunks)} input chunk(s)") + print( + f"✓ Device {device}: Loaded model, processing {len(input_chunks)} input chunk(s)") - for input_chunk_idx, input_chunk in zip(input_chunk_indices, input_chunks): + for input_chunk_idx, input_chunk in zip( + input_chunk_indices, input_chunks): embeddings = embedder.embed_documents(input_chunk) result_queue.put((input_chunk_idx, embeddings)) @@ -133,24 +142,25 @@ def _parallel_embed_worker(device_id, input_chunk_indices, input_chunks, result_ for input_chunk_idx in input_chunk_indices: result_queue.put((input_chunk_idx, None)) + class VectorDB(RagDB): @classmethod def get_default_db_name(cls) -> str: """Get the default database filename for VectorDB.""" return "vector.db" - + def __init__(self, - retriever_model: str = None, - reranker_model: str = None, - device: str = "auto", - load_embeddings: bool = True, - num_embedding_devices: int = 1, - benchmark: bool = False, - hierarchical: bool = False, - embedding_device: str = None, - reranker_device: str = None, - **kwargs - ): + retriever_model: str = None, + reranker_model: str = None, + device: str = "auto", + load_embeddings: bool = True, + num_embedding_devices: int = 1, + benchmark: bool = False, + hierarchical: bool = False, + embedding_device: str = None, + reranker_device: str = None, + **kwargs + ): super().__init__(reranker_model, device, benchmark, reranker_device=reranker_device) self._retriever_model_name = retriever_model self._reranker_model_name = reranker_model @@ -159,7 +169,8 @@ def __init__(self, self._hierarchical = hierarchical self._embedding_lock = None # Embedding device defaults to inheriting from --device. - self._embedding_device = self._determine_device(embedding_device) if embedding_device else self._device + self._embedding_device = self._determine_device( + embedding_device) if embedding_device else self._device # For hierarchical mode: map child_index -> parent_passage self._parent_map = {} @@ -175,7 +186,8 @@ def __init__(self, # For single-device embedding, allocate one GPU now so the reranker # (allocated later) can't pick the same one. Multi-device path # allocates inside _embed_documents_parallel. - if num_embedding_devices == 1 and self._embedding_device in ("cuda", "xpu"): + if num_embedding_devices == 1 and self._embedding_device in ( + "cuda", "xpu"): from utils import resolve_gpu_device self._embedding_device = resolve_gpu_device( self._embedding_device, name="embedding", @@ -183,27 +195,32 @@ def __init__(self, ) # Initialize embedding model with device configuration - model_kwargs = {'device': self._embedding_device, 'local_files_only': True} + model_kwargs = { + 'device': self._embedding_device, + 'local_files_only': True} encode_kwargs = {'normalize_embeddings': True} - + self._embedding_model = HuggingFaceEmbeddings( model_name=self._retriever_model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs ) - self._embedding_dimension = len(self._embedding_model.embed_query("hello world")) - + self._embedding_dimension = len( + self._embedding_model.embed_query("hello world")) + # Check the dtype of the embedding without using numpy test_embedding_raw = self._embedding_model.embed_query("test") - + # Calculate dtype and itemsize from Python native list - if isinstance(test_embedding_raw, list) and len(test_embedding_raw) > 0: + if isinstance(test_embedding_raw, list) and len( + test_embedding_raw) > 0: test_element = test_embedding_raw[0] embedding_dtype = type(test_element) embedding_itemsize = test_element.__sizeof__() # Size in bytes of one element self._embedding_bytes_per_element = embedding_itemsize else: - raise ValueError("Embedding query did not return a valid list of floats.") + raise ValueError( + "Embedding query did not return a valid list of floats.") if self._benchmark: print(f" Embedding element type: {embedding_dtype}") @@ -220,12 +237,12 @@ def __init__(self, embedding_function=self._embedding_model, index=self._index, docstore=self._docstore, - index_to_docstore_id={}, # This will be populated as documents are added + index_to_docstore_id={}, # This will be populated as documents are added ) - + # Keep track of ingested documents self._doc_list = [] - + def _create_vector_index(self, dimension: int): """Create a FAISS HNSW vector index. @@ -241,12 +258,14 @@ def _create_vector_index(self, dimension: int): FAISS HNSW index """ # M: number of connections per layer (higher = better recall, more memory) - # efConstruction: quality of index construction (higher = better quality, slower build) + # efConstruction: quality of index construction (higher = better + # quality, slower build) M = 32 # Default: 32, good balance index = faiss.IndexHNSWFlat(dimension, M) index.hnsw.efConstruction = 200 # Default: 40 index.hnsw.efSearch = 100 # Search-time parameter, can be adjusted later return index + def _get_embeddings_cache_path(self, passages_path: str) -> str: """Get the cache path for embeddings based on passages file path.""" from pathlib import Path @@ -254,15 +273,16 @@ def _get_embeddings_cache_path(self, passages_path: str) -> str: # Replace extension with .emb.pkl cache_path = passages_path.with_suffix('.emb.pkl') return str(cache_path) - + def _save_embeddings_cache(self, embeddings: list, passages_path: str): """Save embeddings to a pickle file for reuse.""" - import os, pickle + import os + import pickle from pathlib import Path - + cache_path = self._get_embeddings_cache_path(passages_path) Path(cache_path).parent.mkdir(parents=True, exist_ok=True) - + if os.path.exists(cache_path): print(f"Embeddings cache exists: {cache_path}") return @@ -270,16 +290,16 @@ def _save_embeddings_cache(self, embeddings: list, passages_path: str): with open(cache_path, 'wb') as f: pickle.dump(embeddings, f) print(f"💾 Saved embeddings cache to {cache_path}") - + def _load_embeddings_cache(self, passages_path: str) -> list: """Load embeddings from cache if available.""" import pickle from pathlib import Path - + cache_path = self._get_embeddings_cache_path(passages_path) if not Path(cache_path).exists(): return None - + try: with open(cache_path, 'rb') as f: embeddings = pickle.load(f) @@ -288,7 +308,7 @@ def _load_embeddings_cache(self, passages_path: str) -> list: except Exception as e: print(f"⚠️ Failed to load embeddings cache: {e}") return None - + def _build_numa_plans(self, num_workers: int): """Parse INFERENCE_EMBEDDING_NUMA_NODES into per-worker (node, cpu_set). @@ -319,22 +339,23 @@ def _build_numa_plans(self, num_workers: int): def _embed_documents_parallel(self, passages: List[str]) -> list: """Generate embeddings using multiple devices in parallel. - + Uses the device type from --device option and spawns multiple workers. - + Args: passages: List of text passages to embed - + Returns: List of embeddings (one per passage) """ import torch import multiprocessing as mp - + # Use the embedding device (may differ from the global --device). # Strip any GPU index already allocated for the single-device path so # the parallel allocator gets a clean device-type string. - base_device = self._embedding_device.split(":")[0] # e.g., 'xpu', 'cuda', 'cpu', 'hpu' + base_device = self._embedding_device.split( + ":")[0] # e.g., 'xpu', 'cuda', 'cpu', 'hpu' num_workers = min(self._num_embedding_devices, len(passages)) @@ -350,7 +371,8 @@ def _embed_documents_parallel(self, passages: List[str]) -> list: override_env="INFERENCE_EMBEDDING_GPU_DEVICES", ) else: - # CPU / HPU: workers use the same device string; index field is ignored. + # CPU / HPU: workers use the same device string; index field is + # ignored. device_indices = list(range(num_workers)) # Set spawn method for device compatibility (required for XPU/CUDA) @@ -360,16 +382,19 @@ def _embed_documents_parallel(self, passages: List[str]) -> list: # Already set, ignore pass - print(f"🚀 Parallel embedding on {num_workers} {base_device.upper()} device(s)...") + print( + f"🚀 Parallel embedding on {num_workers} {base_device.upper()} device(s)...") # Optional per-worker NUMA pinning (CPU + memory + OMP). numa_plans = self._build_numa_plans(num_workers) # Split passages into per-worker input chunks. input_chunk_size = (len(passages) + num_workers - 1) // num_workers - input_chunks = [passages[i:i + input_chunk_size] for i in range(0, len(passages), input_chunk_size)] + input_chunks = [passages[i:i + input_chunk_size] + for i in range(0, len(passages), input_chunk_size)] - print(f" Split {len(passages)} passages into {len(input_chunks)} input chunk(s) (~{input_chunk_size} passages/device)") + print( + f" Split {len(passages)} passages into {len(input_chunks)} input chunk(s) (~{input_chunk_size} passages/device)") result_queue = mp.Queue() processes = [] @@ -379,11 +404,12 @@ def _embed_documents_parallel(self, passages: List[str]) -> list: # input_chunk_idx is the position in `input_chunks`; device_id is the GPU # index from the allocator. They differ when the allocator returns # non-contiguous indices. - for input_chunk_idx, device_id in enumerate(device_indices[:len(input_chunks)]): + for input_chunk_idx, device_id in enumerate( + device_indices[:len(input_chunks)]): numa_plan = numa_plans[input_chunk_idx] if numa_plans else None p = mp.Process(target=_parallel_embed_worker, - args=(device_id, [input_chunk_idx], [input_chunks[input_chunk_idx]], result_queue, - self._retriever_model_name, encode_kwargs, base_device, numa_plan)) + args=(device_id, [input_chunk_idx], [input_chunks[input_chunk_idx]], result_queue, + self._retriever_model_name, encode_kwargs, base_device, numa_plan)) p.start() processes.append(p) @@ -401,36 +427,37 @@ def _embed_documents_parallel(self, passages: List[str]) -> list: for i in range(len(input_chunks)): if i in results: all_embeddings.extend(results[i]) - - print(f"✓ Generated {len(all_embeddings)} embeddings across {num_workers} devices") - + + print( + f"✓ Generated {len(all_embeddings)} embeddings across {num_workers} devices") + return all_embeddings - + def _calculate_index_output_size(self): """Calculate the size of VectorDB output data (db file - metadata). - + Returns the total size in bytes of the serialized database file, excluding configuration metadata overhead. - + The .db file contains: - FAISS index (vectors) - Passages (docstore) - Metadata (small overhead) - + We estimate metadata size and subtract it from total file size. """ from pathlib import Path - + if not hasattr(self, '_serialize_path') or not self._serialize_path: return 0 - + db_path = Path(self._serialize_path) if not db_path.exists(): return 0 - + total_file_size = db_path.stat().st_size return total_file_size - + def ingest(self, passages: List[str], metadatas: List[dict], **kwargs): """Ingest passages with performance monitoring. @@ -457,7 +484,7 @@ def ingest(self, passages: List[str], metadatas: List[dict], **kwargs): print(f" Stored {len(self._parent_map)} parent mappings") total_chars = sum(len(passage) for passage in passages) - + # Handle embeddings: try to load from cache or generate new ones embeddings = None @@ -468,44 +495,58 @@ def ingest(self, passages: List[str], metadatas: List[dict], **kwargs): if embeddings is None: if self._num_embedding_devices > 1: # Use parallel embedding generation across multiple devices - embeddings = self._track_component("embedding_generation", total_chars, len(passages), - lambda: self._embed_documents_parallel(passages), - is_pipeline_input=True) + embeddings = self._track_component("embedding_generation", total_chars, len(passages), + lambda: self._embed_documents_parallel( + passages), + is_pipeline_input=True) else: # Single device embedding generation - embeddings = self._track_component("embedding_generation", total_chars, len(passages), - lambda: self._embedding_model.embed_documents(passages), - is_pipeline_input=True) - - - # Determine batch size: single batch for small datasets, multiple batches for scaling analysis - track_incremental = self._benchmark and self._monitor and len(passages) >= 500 + embeddings = self._track_component("embedding_generation", total_chars, len(passages), + lambda: self._embedding_model.embed_documents( + passages), + is_pipeline_input=True) + + # Determine batch size: single batch for small datasets, multiple + # batches for scaling analysis + track_incremental = self._benchmark and self._monitor and len( + passages) >= 500 if track_incremental: - batch_size = max(1000, len(passages) // 10) # 10 batches, minimum 1000 docs per batch - print(f"🔬 Incremental indexing analysis: {len(passages)} docs in batches of {batch_size}") + # 10 batches, minimum 1000 docs per batch + batch_size = max(1000, len(passages) // 10) + print( + f"🔬 Incremental indexing analysis: {len(passages)} docs in batches of {batch_size}") else: batch_size = len(passages) # Single batch - + # Track total indexing time for component metrics import time indexing_component_start = time.perf_counter() - + # Process in batches for i in range(0, len(passages), batch_size): batch_end = min(i + batch_size, len(passages)) - self._ingest_single_batch(passages, metadatas, embeddings, i, batch_end, track_incremental) - + self._ingest_single_batch( + passages, + metadatas, + embeddings, + i, + batch_end, + track_incremental) + indexing_component_end = time.perf_counter() indexing_component_duration = indexing_component_end - indexing_component_start - + # Create component metrics for the entire indexing operation if not track_incremental: - # For single batch, component was tracked inside _ingest_single_batch + # For single batch, component was tracked inside + # _ingest_single_batch pass elif self._monitor: - # For incremental, create component metrics here for the entire operation - embedding_bytes = len(passages) * self._embedding_dimension * self._embedding_bytes_per_element - + # For incremental, create component metrics here for the entire + # operation + embedding_bytes = len( + passages) * self._embedding_dimension * self._embedding_bytes_per_element + from ingestion_monitor import ComponentMetrics self._monitor.components["faiss_indexing"] = ComponentMetrics( name="faiss_indexing", @@ -513,25 +554,27 @@ def ingest(self, passages: List[str], metadatas: List[dict], **kwargs): input_size_bytes=embedding_bytes, output_size_bytes=embedding_bytes, # Vectors stored in FAISS index items_processed=len(passages), - throughput_mb_per_sec=(embedding_bytes / (1024 * 1024)) / indexing_component_duration if indexing_component_duration > 0 else 0, - throughput_items_per_sec=len(passages) / indexing_component_duration if indexing_component_duration > 0 else 0, + throughput_mb_per_sec=(embedding_bytes / (1024 * 1024)) / + indexing_component_duration if indexing_component_duration > 0 else 0, + throughput_items_per_sec=len( + passages) / indexing_component_duration if indexing_component_duration > 0 else 0, is_pipeline_input=False, is_pipeline_output=True ) - + # Store ingestion metrics for later reporting self._ingestion_start = ingestion_start self._ingestion_item_count = len(passages) self._ingestion_total_chars = total_chars - - # Save embeddings to cache + + # Save embeddings to cache if self._load_embeddings and passages_path: self._save_embeddings_cache(embeddings, passages_path) def _ingest_single_batch(self, passages: List[str], metadatas: List[dict], embeddings: list, - batch_start: int, batch_end: int, track_incremental: bool): + batch_start: int, batch_end: int, track_incremental: bool): """Ingest a batch of passages. Can be used for single or incremental indexing. - + Args: passages: All passages metadatas: All metadata @@ -541,40 +584,42 @@ def _ingest_single_batch(self, passages: List[str], metadatas: List[dict], embed track_incremental: Whether to track this batch for incremental analysis """ import time - + # Extract batch data batch_passages = passages[batch_start:batch_end] - batch_metadatas = metadatas[batch_start:batch_end] if metadatas else [{}] * (batch_end - batch_start) + batch_metadatas = metadatas[batch_start:batch_end] if metadatas else [ + {}] * (batch_end - batch_start) batch_embeddings = embeddings[batch_start:batch_end] - + # Track DB size before adding (for incremental tracking) db_size_before = len(self._doc_list) if track_incremental else 0 - + # Calculate embedding size for this batch - batch_embedding_bytes = len(batch_passages) * self._embedding_dimension * self._embedding_bytes_per_element - + batch_embedding_bytes = len( + batch_passages) * self._embedding_dimension * self._embedding_bytes_per_element + # Time and execute indexing operation indexing_start = time.perf_counter() - + if track_incremental: # For incremental: just add embeddings without component tracking self._vector_store.add_embeddings( - list(zip(batch_passages, batch_embeddings)), + list(zip(batch_passages, batch_embeddings)), batch_metadatas ) else: # For single batch: use component tracking self._track_component("faiss_indexing", batch_embedding_bytes, len(batch_passages), - lambda: self._vector_store.add_embeddings( - list(zip(batch_passages, batch_embeddings)), batch_metadatas), - is_pipeline_output=True) - + lambda: self._vector_store.add_embeddings( + list(zip(batch_passages, batch_embeddings)), batch_metadatas), + is_pipeline_output=True) + indexing_end = time.perf_counter() indexing_time = indexing_end - indexing_start - + # Update document list self._doc_list.extend(batch_passages) - + # Track for incremental analysis if requested if track_incremental and self._monitor: self._monitor.track_incremental_indexing( @@ -582,7 +627,7 @@ def _ingest_single_batch(self, passages: List[str], metadatas: List[dict], embed batch_size=len(batch_passages), indexing_time=indexing_time ) - + def enable_threading(self): """Enable thread-safe access to the embedding model.""" import threading @@ -598,9 +643,11 @@ def lookup(self, query: str, k: int): embedding = self.embed_query(query) else: embedding = self.embed_query(query) - results = self._vector_store.similarity_search_by_vector(embedding, k=k) + results = self._vector_store.similarity_search_by_vector( + embedding, k=k) - # In hierarchical mode: replace child passages with parents, deduplicate + # In hierarchical mode: replace child passages with parents, + # deduplicate if self._hierarchical and self._parent_map: parent_docs = [] seen_parents = set() @@ -631,7 +678,7 @@ def lookup(self, query: str, k: int): return parent_docs return results - + def lookup_with_scores(self, query: str, k: int): """ Lookup documents with similarity scores. @@ -645,13 +692,16 @@ def lookup_with_scores(self, query: str, k: int): embedding = self.embed_query(query) else: embedding = self.embed_query(query) - results_with_scores = self._vector_store.similarity_search_with_score_by_vector(embedding, k=k) + results_with_scores = self._vector_store.similarity_search_with_score_by_vector( + embedding, k=k) # FAISS returns (document, distance) where distance is L2 distance (lower is better) # Convert to similarity score (higher is better) by negating - results_with_similarity = [(doc, -distance) for doc, distance in results_with_scores] + results_with_similarity = [(doc, -distance) + for doc, distance in results_with_scores] - # In hierarchical mode: replace child passages with parents, deduplicate + # In hierarchical mode: replace child passages with parents, + # deduplicate if self._hierarchical and self._parent_map: parent_results = [] seen_parents = set() @@ -683,8 +733,6 @@ def lookup_with_scores(self, query: str, k: int): return results_with_similarity - - def serialize(self, path: str): # Store path for output size calculation self._serialize_path = path @@ -700,21 +748,27 @@ def serialize(self, path: str): parent_map_path = Path(path).with_suffix('.parent_map.pkl') with open(parent_map_path, 'wb') as f: pickle.dump(self._parent_map, f) - print(f"💾 Saved parent map ({len(self._parent_map)} entries) to {parent_map_path}") + print( + f"💾 Saved parent map ({len(self._parent_map)} entries) to {parent_map_path}") # Update output size after serialization (now file exists) if self._benchmark and self._monitor: - self._monitor.set_output_size_callback("faiss_indexing", self._calculate_index_output_size) + self._monitor.set_output_size_callback( + "faiss_indexing", self._calculate_index_output_size) # Report performance after serialization if benchmarking - if self._benchmark and self._monitor and hasattr(self, '_ingestion_start'): + if self._benchmark and self._monitor and hasattr( + self, '_ingestion_start'): # Determine db_type based on whether incremental was used - db_type = "VectorDB (Incremental)" if hasattr(self._monitor, 'indexing_trend') and len(self._monitor.indexing_trend) > 0 else "VectorDB" + db_type = "VectorDB (Incremental)" if hasattr( + self._monitor, 'indexing_trend') and len( + self._monitor.indexing_trend) > 0 else "VectorDB" self._report_performance(self._ingestion_start, self._ingestion_item_count, - self._ingestion_total_chars, db_type) + self._ingestion_total_chars, db_type) def from_serialized(self, path: str): - assert len(self._vector_store.index_to_docstore_id) == 0, "Vector store already has documents" + assert len( + self._vector_store.index_to_docstore_id) == 0, "Vector store already has documents" # Pickled FAISS indexes reference a specific SWIG submodule (e.g. # `faiss.swigfaiss_avx512`). Alias any missing submodules so DBs built # on one host load on hosts with a different faiss-cpu build. @@ -722,8 +776,8 @@ def from_serialized(self, path: str): with open(path, "rb") as f: data = f.read() self._vector_store = FAISS.deserialize_from_bytes(embeddings=self._embedding_model, - serialized=data, - allow_dangerous_deserialization=True) # <--- USE WITH CAUTION - Only deserialize files you trust + serialized=data, + allow_dangerous_deserialization=True) # <--- USE WITH CAUTION - Only deserialize files you trust # Load parent map if hierarchical mode if self._hierarchical: @@ -733,6 +787,8 @@ def from_serialized(self, path: str): if parent_map_path.exists(): with open(parent_map_path, 'rb') as f: self._parent_map = pickle.load(f) - print(f"✓ Loaded parent map ({len(self._parent_map)} entries) from {parent_map_path}") + print( + f"✓ Loaded parent map ({len(self._parent_map)} entries) from {parent_map_path}") else: - print(f"⚠️ Warning: Hierarchical mode enabled but no parent map found at {parent_map_path}") + print( + f"⚠️ Warning: Hierarchical mode enabled but no parent map found at {parent_map_path}") diff --git a/e2e-rag/single_shot_retrieval.py b/e2e-rag/single_shot_retrieval.py index 81e0d69c4c..aa1afe54ad 100644 --- a/e2e-rag/single_shot_retrieval.py +++ b/e2e-rag/single_shot_retrieval.py @@ -26,7 +26,8 @@ from utils import set_deterministic_seeds, setup_llm_config from params import add_all_args -# Taken below from frames: https://huggingface.co/datasets/google/frames-benchmark +# Taken below from frames: +# https://huggingface.co/datasets/google/frames-benchmark DEFAULT_QUERY = "Who won the French Open Mens Singles tournament the year that New York City FC won their first MLS Cup title?" MAX_PASSAGE_PREVIEW = 4096 FULL_DOC_MAX_CHARS = 39000 @@ -58,7 +59,8 @@ def _read_text(path: str) -> str: return path_obj.read_text(encoding="utf-8", errors="ignore") -def _load_document_text(metadata, base_dir=None, default_base_dir="doc_html", max_chars=FULL_DOC_MAX_CHARS): +def _load_document_text(metadata, base_dir=None, + default_base_dir="doc_html", max_chars=FULL_DOC_MAX_CHARS): target_dir = base_dir or default_base_dir base_filename = metadata.get("base_filename") if not base_filename: @@ -79,7 +81,8 @@ def _load_document_text(metadata, base_dir=None, default_base_dir="doc_html", ma return "", None -def _convert_results_to_entries(results, limit=5, full_doc=False, base_dir=None, default_base_dir="doc_html"): +def _convert_results_to_entries( + results, limit=5, full_doc=False, base_dir=None, default_base_dir="doc_html"): entries = [] seen_ids = set() count = 0 @@ -96,9 +99,15 @@ def _convert_results_to_entries(results, limit=5, full_doc=False, base_dir=None, default_base_dir=default_base_dir ) if not content: - content = getattr(doc, "page_content", metadata.get("content", ""))[:MAX_PASSAGE_PREVIEW] + content = getattr( + doc, "page_content", metadata.get( + "content", ""))[ + :MAX_PASSAGE_PREVIEW] else: - content = getattr(doc, "page_content", metadata.get("content", ""))[:MAX_PASSAGE_PREVIEW] + content = getattr( + doc, "page_content", metadata.get( + "content", ""))[ + :MAX_PASSAGE_PREVIEW] source_path = None entry = {"url": url, "content": content} if source_path: @@ -130,7 +139,8 @@ def _generate_llm_answer(query, doc_entries, llm_config): source = doc.get("url") or "Unknown source" snippet = doc.get("content", "").strip() context_parts.append(f"[{idx}] Source: {source}\n{snippet}") - evidence_block = "\n\n".join(context_parts) if context_parts else "No supporting documents were retrieved." + evidence_block = "\n\n".join( + context_parts) if context_parts else "No supporting documents were retrieved." user_prompt = ( "Answer the question using only the provided evidence." " Respond with a single word or short phrase, or 'Unknown' if the evidence is insufficient.\n\n" @@ -153,7 +163,10 @@ def _generate_llm_answer(query, doc_entries, llm_config): "temperature": 0.0, "max_tokens": max_tokens } - response = requests.post(llm_config["service_url"], json=payload, timeout=60) + response = requests.post( + llm_config["service_url"], + json=payload, + timeout=60) response.raise_for_status() data = response.json() message = data["choices"][0]["message"] @@ -164,14 +177,15 @@ def _generate_llm_answer(query, doc_entries, llm_config): return content.strip() if content.strip() else "Unknown" - if __name__ == "__main__": - args = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) - + args = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter) + # Add all parameters from centralized definitions - # This includes: Common, General, BM25, Vector, Strategy, and Reranking parameters + # This includes: Common, General, BM25, Vector, Strategy, and Reranking + # parameters add_all_args(args) - + # Special handling for --eval argument (needs custom type) # Override the default eval argument with custom type for action in args._actions: @@ -190,17 +204,19 @@ def _generate_llm_answer(query, doc_entries, llm_config): args.database = VectorDB.get_default_db_name() # Normalize database path: ensure .db extension for file operations - db_file_path = args.database if args.database.endswith('.db') else f"{args.database}.db" - db_base_name = args.database.replace('.db', '') if args.database.endswith('.db') else args.database + db_file_path = args.database if args.database.endswith( + '.db') else f"{args.database}.db" + db_base_name = args.database.replace( + '.db', '') if args.database.endswith('.db') else args.database # Create VectorDB instance (pass base name without .db) rag_db = VectorDB(retriever_model=args.retriever_model, reranker_model=args.reranker_model, device=args.device, - database=db_base_name, - load_embeddings=args.load_embeddings, num_embedding_devices=args.num_embedding_devices, - hierarchical=args.hierarchical, - embedding_device=args.embedding_device, - reranker_device=args.reranker_device, - benchmark=args.benchmark) + database=db_base_name, + load_embeddings=args.load_embeddings, num_embedding_devices=args.num_embedding_devices, + hierarchical=args.hierarchical, + embedding_device=args.embedding_device, + reranker_device=args.reranker_device, + benchmark=args.benchmark) if os.path.exists(db_file_path): # Load existing database @@ -208,18 +224,21 @@ def _generate_llm_answer(query, doc_entries, llm_config): rag_db.from_serialized(db_file_path) else: if not args.ingest: - raise ValueError("Either --database (existing) or --ingest (to create new) must be provided") - + raise ValueError( + "Either --database (existing) or --ingest (to create new) must be provided") + # Ingest from file or folder tic = time.time() rag_db.ingest_from_path(args.ingest) - + # Get number of passages for timing calculation - num_passages = len(rag_db._doc_list) # This should be available after ingestion + # This should be available after ingestion + num_passages = len(rag_db._doc_list) toc = time.time() - ingestion_speed = num_passages/(toc-tic) - print(f"Ingestion of {num_passages} passages took {toc - tic:.2f} seconds. {ingestion_speed:.2f} docs/sec") - + ingestion_speed = num_passages / (toc - tic) + print( + f"Ingestion of {num_passages} passages took {toc - tic:.2f} seconds. {ingestion_speed:.2f} docs/sec") + # Save the database (unless --no-save is specified) if not args.no_save: print(f"Saving database to {db_file_path}") @@ -229,16 +248,20 @@ def _generate_llm_answer(query, doc_entries, llm_config): # Run evaluation or single query lookup if args.eval: - max_queries = args.eval if isinstance(args.eval, int) and not isinstance(args.eval, bool) and args.eval > 0 else None - - # Build strategy_params with correct parameter names for filter function + max_queries = args.eval if isinstance( + args.eval, int) and not isinstance( + args.eval, bool) and args.eval > 0 else None + + # Build strategy_params with correct parameter names for filter + # function strategy_params = {"max_results": args.max_results} if args.retrieval_strategy == "top_p": strategy_params["p"] = args.top_p elif args.retrieval_strategy == "relative": strategy_params["ratio"] = args.relative_ratio - + answer_records = [] + def handle_result(prompt, retrieved_docs, metrics): urls = _extract_unique_urls(retrieved_docs) answer_text = None @@ -249,7 +272,8 @@ def handle_result(prompt, retrieved_docs, metrics): full_doc=args.full_doc_context, base_dir=doc_base_dir ) - answer_text = _generate_llm_answer(prompt, doc_entries, llm_config) + answer_text = _generate_llm_answer( + prompt, doc_entries, llm_config) print(f"LLM Answer: {answer_text}") if args.save_results: record = { @@ -270,19 +294,21 @@ def handle_result(prompt, retrieved_docs, metrics): retrieval_strategy=args.retrieval_strategy, detailed_analysis=True, difficulty=args.difficulty, - result_handler=handle_result if (args.generate_answer or args.save_results) else None, + result_handler=handle_result if ( + args.generate_answer or args.save_results) else None, **strategy_params ) - + # Save results for optimization results_data = { - "accuracy": metrics.get('legacy_score', 0.0), # Backward compatibility + # Backward compatibility + "accuracy": metrics.get('legacy_score', 0.0), "metrics": metrics } - + with open("results.json", "w") as f: json.dump(results_data, f, indent=2) - + if args.save_results: with open("result_single_shot.json", "w") as f: json.dump({ @@ -292,13 +318,13 @@ def handle_result(prompt, retrieved_docs, metrics): exit(0) # Exit after evaluation else: # Single query lookup - reuse evaluation code for consistency - + strategy_params = {} if args.retrieval_strategy == "top_p": strategy_params["p"] = args.top_p elif args.retrieval_strategy == "relative": strategy_params["ratio"] = args.relative_ratio - + # Time the retrieval tic = time.time() need_results = args.generate_answer or args.save_results @@ -329,7 +355,8 @@ def handle_result(prompt, retrieved_docs, metrics): full_doc=args.full_doc_context, base_dir=doc_base_dir ) - answer_value = _generate_llm_answer(args.query, doc_entries, llm_config) + answer_value = _generate_llm_answer( + args.query, doc_entries, llm_config) print(f"LLM Answer: {answer_value}") if args.save_results: @@ -345,5 +372,5 @@ def handle_result(prompt, retrieved_docs, metrics): "results": [record] }, f, indent=2) toc = time.time() - + print(f"\nLookup took {toc - tic:.3f} seconds") diff --git a/e2e-rag/text_splitter.py b/e2e-rag/text_splitter.py index 14c09a6752..a83f89718d 100644 --- a/e2e-rag/text_splitter.py +++ b/e2e-rag/text_splitter.py @@ -30,37 +30,39 @@ def clean_text(text: str) -> str: """Clean and normalize text.""" # Normalize whitespace text = re.sub(r'\s+', ' ', text.strip()) - + # Remove excessive newlines but preserve paragraph structure text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text) - + return text -def find_sentence_boundary(text: str, start: int, end: int, search_window: int = 100) -> int: +def find_sentence_boundary(text: str, start: int, end: int, + search_window: int = 100) -> int: """ Find the best sentence boundary within the search window. - + Args: text: The text to search in start: Start position of the passage end: Desired end position search_window: Number of characters to look back for sentence boundary - + Returns: Best boundary position """ if end >= len(text): return len(text) - + # Look for sentence endings within the search window search_start = max(start, end - search_window) sentence_endings = ['.', '!', '?', '\n'] - + best_break = end for i in range(end - 1, search_start - 1, -1): if text[i] in sentence_endings: - # Check if it's followed by whitespace and uppercase letter (proper sentence end) + # Check if it's followed by whitespace and uppercase letter (proper + # sentence end) if i + 1 < len(text) and text[i + 1].isspace(): # Look for the next non-whitespace character j = i + 1 @@ -69,11 +71,12 @@ def find_sentence_boundary(text: str, start: int, end: int, search_window: int = if j < len(text) and (text[j].isupper() or text[j].isdigit()): best_break = i + 1 break - + return best_break -def split_into_passages(text: str, max_length: int = 512, overlap: int = 50) -> List[str]: +def split_into_passages(text: str, max_length: int = 512, + overlap: int = 50) -> List[str]: """ Split text into passages suitable for retrieval systems like ColBERT. @@ -87,7 +90,7 @@ def split_into_passages(text: str, max_length: int = 512, overlap: int = 50) -> """ # Clean up the text text = clean_text(text) - + if len(text) <= max_length: return [text] if text else [] @@ -97,7 +100,8 @@ def split_into_passages(text: str, max_length: int = 512, overlap: int = 50) -> while start < len(text): end = start + max_length - # If we're not at the end of the text, try to break at a sentence boundary + # If we're not at the end of the text, try to break at a sentence + # boundary if end < len(text): end = find_sentence_boundary(text, start, end) @@ -113,7 +117,8 @@ def split_into_passages(text: str, max_length: int = 512, overlap: int = 50) -> return passages -def split_into_fixed_passages(text: str, fixed_length: int = 256, overlap: int = 32) -> List[str]: +def split_into_fixed_passages( + text: str, fixed_length: int = 256, overlap: int = 32) -> List[str]: """ Split text into fixed-length passages with exact character counts. Useful for consistent passage lengths across datasets. @@ -128,7 +133,7 @@ def split_into_fixed_passages(text: str, fixed_length: int = 256, overlap: int = """ # Clean up the text text = clean_text(text) - + if len(text) <= fixed_length: return [text] if text else [] @@ -138,7 +143,7 @@ def split_into_fixed_passages(text: str, fixed_length: int = 256, overlap: int = while start < len(text): end = min(start + fixed_length, len(text)) passage = text[start:end].strip() - + if passage: passages.append(passage) @@ -150,15 +155,16 @@ def split_into_fixed_passages(text: str, fixed_length: int = 256, overlap: int = return passages -def create_passage_metadata(filename: str, passage_index: int, original_url: Optional[str] = None) -> dict: +def create_passage_metadata( + filename: str, passage_index: int, original_url: Optional[str] = None) -> dict: """ Create standardized metadata for a passage. - + Args: filename: Source filename (PDF or HTML) passage_index: Index of this passage within the document original_url: Original URL if available - + Returns: Dictionary containing passage metadata """ @@ -166,19 +172,20 @@ def create_passage_metadata(filename: str, passage_index: int, original_url: Opt base_filename = filename if '.' in filename: base_filename = '.'.join(filename.split('.')[:-1]) - + metadata = { 'index': passage_index, 'base_filename': base_filename } - + if original_url: metadata['original_url'] = original_url - + return metadata -def estimate_passage_count(text: str, max_length: int = 512, overlap: int = 50) -> int: +def estimate_passage_count( + text: str, max_length: int = 512, overlap: int = 50) -> int: """ Estimate the number of passages that will be created from text. Useful for progress tracking without actually splitting. @@ -227,7 +234,8 @@ def split_into_hierarchical_passages( if len(text) <= parent_length: # Single parent case - still create child chunks parent_text = text - children = split_into_passages(parent_text, child_length, child_overlap) + children = split_into_passages( + parent_text, child_length, child_overlap) results = [] for child_idx, child_text in enumerate(children): @@ -246,7 +254,8 @@ def split_into_hierarchical_passages( all_results = [] for parent_id, parent_text in enumerate(parent_chunks): # Split parent into children - children = split_into_passages(parent_text, child_length, child_overlap) + children = split_into_passages( + parent_text, child_length, child_overlap) # Create hierarchical entries for child_idx, child_text in enumerate(children): @@ -257,4 +266,4 @@ def split_into_hierarchical_passages( 'child_index': child_idx }) - return all_results \ No newline at end of file + return all_results diff --git a/e2e-rag/utils.py b/e2e-rag/utils.py index 6867c2aad3..50b9bddb05 100644 --- a/e2e-rag/utils.py +++ b/e2e-rag/utils.py @@ -27,7 +27,6 @@ from typing import Dict, Optional, Union, Any - def load_url_mapping(directory: str) -> Dict[str, str]: """Load URL mapping from url_mapping.json in specified directory.""" mapping_path = Path(directory) / "url_mapping.json" @@ -70,27 +69,28 @@ def set_deterministic_seeds(seed: int = 42) -> None: def filter_dataset_by_difficulty(df, difficulty: int = 0): """ Filter dataset by minimum number of answer links (difficulty level). - + Args: df: pandas DataFrame with dataset difficulty: Minimum number of answer links required (0 = no filtering) - + Returns: Filtered DataFrame with queries having >= difficulty answer links """ if difficulty <= 0: return df - + # Count answer links for each row link_counts = df.apply( - lambda row: sum(1 for col in df.columns - if col.startswith('wikipedia_link_') and row.notna()[col]), + lambda row: sum(1 for col in df.columns + if col.startswith('wikipedia_link_') and row.notna()[col]), axis=1 ) - + filtered_df = df[link_counts >= difficulty].reset_index(drop=True) - print(f"Filtered dataset by difficulty >= {difficulty}: {len(filtered_df)} queries remaining (from {len(df)} total)") - + print( + f"Filtered dataset by difficulty >= {difficulty}: {len(filtered_df)} queries remaining (from {len(df)} total)") + return filtered_df @@ -156,7 +156,8 @@ def set_mempolicy_membind(node: int) -> None: libnuma = ctypes.CDLL("libnuma.so.1", use_errno=True) rc = libnuma.set_mempolicy(MPOL_BIND, ctypes.byref(nodemask), maxnode) except (OSError, AttributeError): - # set_mempolicy is syscall 238 on x86_64; 237 on aarch64 (rare in this codebase). + # set_mempolicy is syscall 238 on x86_64; 237 on aarch64 (rare in this + # codebase). SYS_SET_MEMPOLICY_X86_64 = 238 libc = ctypes.CDLL("libc.so.6", use_errno=True) rc = libc.syscall(SYS_SET_MEMPOLICY_X86_64, MPOL_BIND, @@ -237,7 +238,8 @@ def pin_worker_to_node(node: int, cpu_set: list) -> None: if "OMP_NUM_THREADS" not in os.environ: os.environ["OMP_NUM_THREADS"] = str(len(cpu_set)) - print(f" [worker] node={node} cores={cpu_set[0]}..{cpu_set[-1]} ({len(cpu_set)}) OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}") + print( + f" [worker] node={node} cores={cpu_set[0]}..{cpu_set[-1]} ({len(cpu_set)}) OMP_NUM_THREADS={os.environ['OMP_NUM_THREADS']}") def apply_numa_pinning() -> None: @@ -275,7 +277,8 @@ def apply_numa_pinning() -> None: print(f" sched_setaffinity failed: {e}; skipping pinning") return - print(f" Pinned CPU affinity to {len(cores)} cores: {cores[0]}..{cores[-1]}") + print( + f" Pinned CPU affinity to {len(cores)} cores: {cores[0]}..{cores[-1]}") def apply_cpu_threading_env() -> None: @@ -394,17 +397,20 @@ def _parse_override(self, override_env: str) -> list: ) return indices - def allocate(self, count: int = 1, name: str = "", override_env: str = "") -> list: + def allocate(self, count: int = 1, name: str = "", + override_env: str = "") -> list: if override_env: requested = self._parse_override(override_env) if requested: avail = [i for i in requested if i not in self._taken] source = f"{override_env}={','.join(map(str, requested))}" else: - avail = [i for i in self._all_indices if i not in self._taken and self._is_empty(i)] + avail = [ + i for i in self._all_indices if i not in self._taken and self._is_empty(i)] source = "auto" else: - avail = [i for i in self._all_indices if i not in self._taken and self._is_empty(i)] + avail = [ + i for i in self._all_indices if i not in self._taken and self._is_empty(i)] source = "auto" if len(avail) < count: @@ -417,7 +423,8 @@ def allocate(self, count: int = 1, name: str = "", override_env: str = "") -> li chosen = avail[:count] self._taken.update(chosen) label = name or self.device_type - print(f" Allocated {self.device_type}:{chosen} for {label} (via {source})") + print( + f" Allocated {self.device_type}:{chosen} for {label} (via {source})") return chosen @@ -428,14 +435,16 @@ def get_device_allocator(device_type: str) -> DeviceAllocator: return _DEVICE_ALLOCATORS[device_type] -def resolve_gpu_device(device: str, name: str = "", override_env: str = "") -> str: +def resolve_gpu_device(device: str, name: str = "", + override_env: str = "") -> str: """Map a bare device type ('cuda' / 'xpu') to a specific 'cuda:N' string. Returns `device` unchanged for cpu/hpu/auto/already-indexed strings. Errors if no empty GPU is available (use override_env to override). """ if device in ("cuda", "xpu"): - idx = get_device_allocator(device).allocate(count=1, name=name, override_env=override_env)[0] + idx = get_device_allocator(device).allocate( + count=1, name=name, override_env=override_env)[0] return f"{device}:{idx}" return device @@ -447,7 +456,8 @@ def detect_device() -> str: if torch.cuda.is_available(): if getattr(torch.version, "hip", None): - print(f"Using AMD ROCm GPU (torch.version.hip={torch.version.hip})") + print( + f"Using AMD ROCm GPU (torch.version.hip={torch.version.hip})") else: print("Using NVIDIA CUDA GPU") return "cuda" @@ -473,12 +483,14 @@ def get_model_info_from_service(service_url: str) -> Optional[Dict]: """Get model information from LLM service.""" try: # Try OpenAI-compatible API first - models_response = requests.get(f"{service_url.rstrip('/v1/chat/completions').rstrip('/v1')}/v1/models", timeout=10) + models_response = requests.get( + f"{service_url.rstrip('/v1/chat/completions').rstrip('/v1')}/v1/models", + timeout=10) if models_response.status_code == 200: models_data = models_response.json() if "data" in models_data and len(models_data["data"]) > 0: return models_data["data"][0] - + # Try alternative endpoints base_url = service_url.rstrip('/v1/chat/completions').rstrip('/v1') for endpoint in ["/models", "/info", "/v1/model"]: @@ -486,19 +498,19 @@ def get_model_info_from_service(service_url: str) -> Optional[Dict]: response = requests.get(f"{base_url}{endpoint}", timeout=5) if response.status_code == 200: return response.json() - except: + except BaseException: continue - + except Exception as e: print(f"Warning: Could not auto-detect model from {service_url}: {e}") - + return None def get_model_name_from_service(service_url: str) -> str: """Auto-detect model name from LLM service.""" model_info = get_model_info_from_service(service_url) - + if model_info: # Try different possible fields for model name for field in ["id", "model", "name", "model_name"]: @@ -512,18 +524,20 @@ def get_model_name_from_service(service_url: str) -> str: def get_max_tokens_from_service(service_url: str) -> int: """Auto-detect max tokens from LLM service.""" model_info = get_model_info_from_service(service_url) - + if model_info: # Try different possible fields for max tokens - for field in ["max_tokens", "max_length", "context_length", "max_context_length"]: + for field in ["max_tokens", "max_length", + "context_length", "max_context_length"]: if field in model_info and isinstance(model_info[field], int): return model_info[field] - + # Default fallback based on common models return 10240 -def resolve_config_value(value: Union[str, int], auto_func, *args) -> Union[str, int]: +def resolve_config_value( + value: Union[str, int], auto_func, *args) -> Union[str, int]: """Resolve configuration value that might be 'auto'.""" if value == "auto": return auto_func(*args) @@ -537,21 +551,22 @@ def get_device_config(): "device_count": 1, "device_memory": None } - + if torch is None: return config - + if config["device_type"] == "hpu": config["device_count"] = torch.hpu.device_count() - + elif config["device_type"] == "cuda": config["device_count"] = torch.cuda.device_count() if torch.cuda.is_available(): - config["device_memory"] = torch.cuda.get_device_properties(0).total_memory - + config["device_memory"] = torch.cuda.get_device_properties( + 0).total_memory + elif config["device_type"] == "xpu": config["device_count"] = torch.xpu.device_count() - + return config @@ -589,7 +604,8 @@ def setup_llm_config(args): grader_service_url = getattr(args, 'grader_service_url', None) or base_url grader_model_name = getattr(args, 'grader_model', None) or model_name query_service_url = getattr(args, 'query_service_url', None) or base_url - sufficiency_service_url = getattr(args, 'sufficiency_service_url', None) or base_url + sufficiency_service_url = getattr( + args, 'sufficiency_service_url', None) or base_url return { "service_url": base_url, diff --git a/loadgen/issue_query_controller.cc b/loadgen/issue_query_controller.cc index 4c5ca66f0c..c1abea9d14 100644 --- a/loadgen/issue_query_controller.cc +++ b/loadgen/issue_query_controller.cc @@ -459,8 +459,8 @@ void IssueQueryController::IssueQueriesInternal(size_t query_stride, #if USE_NEW_LOGGING_FORMAT std::stringstream ss; ss << "IssueQueryThread " << thread_idx - << " Ending early: Too many outstanding queries." - << " issued " << queries_issued_total << " outstanding " + << " Ending early: Too many outstanding queries." << " issued " + << queries_issued_total << " outstanding " << queries_outstanding; MLPERF_LOG_ERROR(detail, "error_runtime", ss.str()); #else @@ -499,8 +499,8 @@ void IssueQueryController::IssueQueriesInternal(size_t query_stride, #if USE_NEW_LOGGING_FORMAT std::stringstream ss; ss << "IssueQueryThread " << thread_idx - << " Ending early: Max query count reached." - << " query_count " << queries_issued; + << " Ending early: Max query count reached." << " query_count " + << queries_issued; MLPERF_LOG_ERROR(detail, "error_runtime", ss.str()); #else detail.Error("IssueQueryThread ", std::to_string(thread_idx), @@ -519,8 +519,8 @@ void IssueQueryController::IssueQueriesInternal(size_t query_stride, #if USE_NEW_LOGGING_FORMAT std::stringstream ss; ss << "IssueQueryThread " << thread_idx - << " Ending early: Max test duration reached." - << " duration_ns " << duration.count(); + << " Ending early: Max test duration reached." << " duration_ns " + << duration.count(); MLPERF_LOG_ERROR(detail, "error_runtime", ss.str()); #else detail.Error("IssueQueryThread ", std::to_string(thread_idx), diff --git a/loadgen/logging.cc b/loadgen/logging.cc index d7e83e54b9..807c1954a8 100644 --- a/loadgen/logging.cc +++ b/loadgen/logging.cc @@ -812,8 +812,7 @@ void Logger::CollectTlsLoggerStats(TlsLogger* tls_logger) { if (max_entry_vector_size > kTlsLogReservedEntryCount) { #if USE_NEW_LOGGING_FORMAT std::stringstream msg; - msg << "Logging allocation detected:" - << " tid: " << tls_logger->Tid() + msg << "Logging allocation detected:" << " tid: " << tls_logger->Tid() << " reserved_entries: " << kTlsLogReservedEntryCount << " max_entries: " << max_entry_vector_size; MLPERF_LOG_WARNING((*this), "warning_generic_message", msg.str());