diff --git a/search_r1/search/retrieval_server.py b/search_r1/search/retrieval_server.py index f3969898..8c93f86a 100644 --- a/search_r1/search/retrieval_server.py +++ b/search_r1/search/retrieval_server.py @@ -337,11 +337,11 @@ def retrieve_endpoint(request: QueryRequest): if not request.topk: request.topk = config.retrieval_topk # fallback to default - # Perform batch retrieval + # Perform batch retrieval (always fetch scores internally, filter at response stage) results, scores = retriever.batch_search( query_list=request.queries, num=request.topk, - return_score=request.return_scores + return_score=True ) # Format response