7 changes: 3 additions & 4 deletions data_pipeline/main.py
@@ -165,10 +165,9 @@ async def run_full_pipeline(
        logger.info("💾 Step 6: Saving dataframes to CSV files...")
        try:
            self._save_dataframes(dataframes, name=timestamp)
-       except Exception:
-           import pdb
-
-           pdb.set_trace()
+       except Exception as e:
+           logger.error("Failed to save dataframes: %s", e)
+           raise

        # Step 7: Combine related works
        all_related_works = glob.glob(f"{self.config.output_dir}/related_works/*.csv")
1 change: 1 addition & 0 deletions deepscholar_base/configs.py
@@ -19,6 +19,7 @@ class Config:

# Common for both agentic and recursive search
enable_web_search: bool = True
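    # Assumes the exa_py package is installed; the Exa client falls back to the
    # EXA_API_KEY environment variable when no key is passed explicitly.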
enable_exa_search: bool = False
per_query_max_search_results_count: int = 10

# Only for agentic search
185 changes: 185 additions & 0 deletions deepscholar_base/search/agentic_search.py
@@ -18,13 +18,16 @@
from openai.types.shared import Reasoning
import re
from lotus import web_search, web_extract, WebSearchCorpus
from exa_py import Exa

try:
from deepscholar_base.utils.prompts import (
openai_sdk_arxiv_search_system_prompt,
openai_sdk_arxiv_search_system_prompt_without_cutoff,
openai_sdk_search_system_prompt,
openai_sdk_search_system_prompt_without_cutoff,
openai_sdk_exa_search_system_prompt,
openai_sdk_exa_search_system_prompt_without_cutoff,
)
from deepscholar_base.configs import Configs
except ImportError:
@@ -33,6 +36,8 @@
openai_sdk_arxiv_search_system_prompt_without_cutoff,
openai_sdk_search_system_prompt,
openai_sdk_search_system_prompt_without_cutoff,
openai_sdk_exa_search_system_prompt,
openai_sdk_exa_search_system_prompt_without_cutoff,
)
from ..configs import Configs

@@ -61,6 +66,7 @@ def merge_papers_df(self, new: pd.DataFrame) -> None:
class ToolTypes(Enum):
ARXIV = "arxiv"
WEB = "web"
EXA = "exa"

def to_web_search_corpus(self) -> WebSearchCorpus:
if self == ToolTypes.ARXIV:
@@ -75,6 +81,8 @@ def to_rename_map(self) -> dict[str, str]:
return {"link": "url", "abstract": "snippet", "published": "date"}
elif self == ToolTypes.WEB:
return {"content": "snippet"}
elif self == ToolTypes.EXA:
return {}
else:
raise ValueError(f"Invalid search type: {self}")

@@ -293,6 +301,174 @@ async def read_webpage_full_text(
"""
return await _read_content(ctx, ToolTypes.WEB, urls)


# ---------- Exa Search Functions ----------
def _exa_search(
queries: list[str],
max_results: int = 10,
end_date: datetime | None = None,
) -> pd.DataFrame:
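    # Synchronous helper around exa_py's search endpoint. Plain exa.search()
    # results may not carry a `text` attribute (hence the hasattr guard below);
    # exa_py's search_and_contents() would be needed for reliable snippets.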
exa = Exa()
all_results = []
for query in queries:
try:
search_kwargs = {
"query": query,
"num_results": max_results,
"type": "auto",
"category": "research paper",
}
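            # Exa expects ISO-8601 timestamps for its published-date filters.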
if end_date:
search_kwargs["end_published_date"] = end_date.strftime("%Y-%m-%dT%H:%M:%S.000Z")
results = exa.search(**search_kwargs)
for r in results.results:
all_results.append({
"title": r.title or "",
"url": r.url or "",
"snippet": r.text if hasattr(r, "text") and r.text else "",
"date": r.published_date or "",
"query": query,
})
        except Exception:
            # Skip queries that fail rather than aborting the whole batch.
            continue
if not all_results:
return pd.DataFrame(columns=["title", "url", "snippet", "date", "query"])
return pd.DataFrame(all_results)


def _exa_get_contents(
urls: list[str],
) -> pd.DataFrame:
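    # Synchronous helper around exa_py's get_contents endpoint: one row per URL
    # that Exa could fetch, carrying the full extracted text.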
exa = Exa()
all_results = []
try:
results = exa.get_contents(urls, text=True)
for r in results.results:
all_results.append({
"title": r.title or "",
"url": r.url or "",
"full_text": r.text if hasattr(r, "text") and r.text else "",
})
    except Exception:
        # Extraction failures fall through to the empty-frame return below.
        pass
if not all_results:
return pd.DataFrame(columns=["title", "url", "full_text"])
return pd.DataFrame(all_results)


@function_tool
async def search_exa(ctx: RunContextWrapper[AgentContext], queries: list[str]) -> str:
"""
Search for academic papers and research using Exa's neural search engine.
Exa excels at finding research papers, technical blog posts, and scholarly content
that may not be indexed on arXiv.

    Returns up to the configured per-query maximum number of results (10 by default),
    with clear separation showing which results correspond to which query.
    Each entry is formatted as "Title (date): URL".

Guidelines for Constructing Effective Exa Search Queries:
- Use natural language queries that describe the concept you are looking for.
- Be specific about the research area, methodology, or finding.
- Use different queries to cover distinct aspects of the topic.

Args:
queries: A list of search query strings.
Example: ["retrieval augmented generation for question answering", "dense passage retrieval methods"]
"""
ctx.context.configs.logger.info(f"Searching Exa for queries: {queries}")
cutoff = ctx.context.end_date
all_results_sections = []
successful_queries_list = []

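    # The exa_py client is synchronous, so the batch search runs on the default
    # thread-pool executor to keep the agent's event loop responsive.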
exa_df = await asyncio.get_event_loop().run_in_executor(
None,
lambda: _exa_search(
queries,
max_results=ctx.context.configs.per_query_max_search_results_count,
end_date=cutoff,
),
)

for query in queries:
        # .copy() avoids pandas' SettingWithCopyWarning when the context column is added below.
        query_df = exa_df[exa_df["query"] == query].copy() if not exa_df.empty else pd.DataFrame()
if query_df.empty:
all_results_sections.append(f"=== QUERY: {query} ===\nNo results found.")
continue

query_df["context"] = query_df.apply(
lambda row: f"{row.get('title', '')}[{row.get('url', '')}]: {row.get('snippet', '')}", axis=1
)
required_columns = ["title", "url", "snippet", "query", "context", "date"]
for col in required_columns:
if col not in query_df.columns:
query_df[col] = ""

ctx.context.merge_papers_df(query_df[required_columns])
successful_queries_list.append(query)
results_text = "\n".join(
f"{row.get('title', 'Untitled')} ({row.get('date', '')}): {row.get('url', '')}"
for _, row in query_df.iterrows()
)
all_results_sections.append(f"=== QUERY: {query} ===\n{results_text}")

ctx.context.queries.append(["exa_search"] + successful_queries_list)
ctx.context.configs.logger.info(
f"Exa search successful queries: {successful_queries_list}, "
f"collected total references: {len(ctx.context.papers_df) if ctx.context.papers_df is not None else 0}"
)
return "\n\n".join(all_results_sections)


@function_tool
async def read_exa_contents(
ctx: RunContextWrapper[AgentContext], urls: list[str]
) -> str:
"""
Retrieve the full text content for a list of URLs using Exa's content extraction.
Use this after `search_exa` surfaces promising URLs to get detailed content
for synthesis. The output includes each page's URL, title, and extracted text.

Args:
urls: A list of URLs to extract content from.
Example: ["https://arxiv.org/abs/2301.00001", "https://example.com/paper"]
"""
ctx.context.configs.logger.info(f"Reading Exa content for urls: {urls}")
successful_inputs = ["exa_read"]
try:
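        # Offload the synchronous Exa contents call to a worker thread.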
contents_df = await asyncio.get_event_loop().run_in_executor(
None, lambda: _exa_get_contents(urls)
)
if contents_df.empty:
ctx.context.queries.append(successful_inputs)
return "No content found for the provided URLs."

results_text = []
for _, row in contents_df.iterrows():
full_text = row.get("full_text", "")
truncated = full_text[:1000] if full_text else ""
results_text.append(f"{row.get('url', '')}\n\n{truncated}")
paper_row = pd.DataFrame([{
"title": row.get("title", ""),
"url": row.get("url", ""),
"snippet": full_text,
"query": "exa_read",
"context": f"[{row.get('url', '')}]: {truncated}",
"date": "",
}])
ctx.context.merge_papers_df(paper_row)
successful_inputs.append(row.get("url", ""))

ctx.context.queries.append(successful_inputs)
ctx.context.configs.logger.info(
f"Exa read successful inputs: {successful_inputs}, "
f"collected total references: {len(ctx.context.papers_df) if ctx.context.papers_df is not None else 0}"
)
return "\n\n---\n\n".join(results_text) if results_text else "No content found"
except Exception as e:
ctx.context.queries.append(successful_inputs)
return f"Error extracting content from Exa for {urls}: {e}"


def _call_model_input_filter(input: CallModelData[AgentContext]) -> ModelInputData:
"""
This function is used to trim input to less than search_lm's max_ctx_len.
@@ -374,6 +550,15 @@ async def agentic_search(
if not end_date
else openai_sdk_search_system_prompt
)
if configs.enable_exa_search:
configs.logger.info("Exa search is enabled, adding Exa search tools and prompt.")
tools.append(search_exa)
tools.append(read_exa_contents)
prompt = (
openai_sdk_exa_search_system_prompt_without_cutoff
if not end_date
else openai_sdk_exa_search_system_prompt
)
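        # Note: when web search is also enabled, the Exa prompt selected here
        # supersedes the web-search prompt chosen above.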
agent = Agent(
name="Research Assistant",
instructions=prompt,
106 changes: 106 additions & 0 deletions deepscholar_base/search/recursive_search.py
@@ -6,6 +6,8 @@
import time
from lotus import web_search
from datetime import datetime, timedelta
from exa_py import Exa

try:
from deepscholar_base.utils.summary_generation import generate_section_summary
from deepscholar_base.utils.prompts import (
@@ -62,6 +64,20 @@ async def recursive_search(
else:
results = pd.concat([results, web_results])
results.fillna("", inplace=True)
if configs.enable_exa_search:
exa_queries, exa_results = await _exa_multiquery_search(
web_multiquery_system_prompt,
topic,
background,
configs,
end_date,
)
queries.extend(exa_queries)
if results is None:
results = exa_results
elif exa_results is not None and not exa_results.empty:
results = pd.concat([results, exa_results])
results.fillna("", inplace=True)
if results is not None:
results = results.drop_duplicates(subset=["url"])
else:
@@ -226,6 +242,96 @@ def generate_context(row):
return df


########### Exa Search ###########
async def _exa_multiquery_search(
instruction: str,
topic: str,
background: str,
configs: Configs,
end_date: datetime | None = None,
) -> tuple[list[str], pd.DataFrame]:
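    # Mirrors the web multiquery flow: derive search queries from the topic and
    # background, then fan them out to Exa concurrently.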
queries = await _generate_queries(
topic, background, instruction, end_date, configs
)
configs.logger.info(f"Searching Exa for queries: {queries}")
results = await _safe_exa_async_search(configs, queries, configs.per_query_max_search_results_count, end_date=end_date)
return queries, results


async def _safe_exa_async_search(
configs: Configs,
queries: list[str],
K: int,
end_date: datetime | None = None,
) -> pd.DataFrame:
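    # Runs one Exa search per query concurrently, concatenates the frames, and
    # de-duplicates on URL; failed queries contribute empty frames.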
dfs = await asyncio.gather(
*[_process_single_exa_search(configs, query, K, end_date) for query in queries]
)
if dfs:
result_df = pd.concat(dfs, ignore_index=True).fillna("")
else:
result_df = pd.DataFrame()
if not result_df.empty and "url" in result_df.columns:
result_df = result_df.drop_duplicates(subset=["url"])
return result_df


async def _process_single_exa_search(
configs: Configs,
query: str,
K: int,
end_date: datetime | None = None,
) -> pd.DataFrame:
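    # Retries up to 5 times with a 1-second backoff, returning an empty frame
    # with the expected schema once all attempts are exhausted.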
required_columns = ["title", "url", "snippet", "query", "context", "date"]
count = 5
while count > 0:
try:
exa = Exa()
search_kwargs = {
"query": query,
"num_results": K,
"type": "auto",
"category": "research paper",
}
if end_date:
search_kwargs["end_published_date"] = end_date.strftime("%Y-%m-%dT%H:%M:%S.000Z")
results = await asyncio.get_event_loop().run_in_executor(
None, lambda: exa.search(**search_kwargs)
)
rows = []
for r in results.results:
rows.append({
"title": r.title or "",
"url": r.url or "",
"snippet": r.text if hasattr(r, "text") and r.text else "",
"date": r.published_date or "",
"query": query,
})
if not rows:
return pd.DataFrame(columns=required_columns)
df = pd.DataFrame(rows)
df["context"] = df.apply(
lambda row: f"{row.get('title', '')}[{row.get('url', '')}]: {row.get('snippet', '')}", axis=1
)
for col in required_columns:
if col not in df.columns:
df[col] = ""
if end_date is not None and "date" in df.columns:
try:
df["_date"] = pd.to_datetime(df["date"], errors="coerce")
cutoff_date = end_date - timedelta(days=1)
df = df[df["_date"].dt.date <= cutoff_date.date()]
                    df = df.drop(columns=["_date"])  # avoid inplace drop on a filtered slice
except Exception as e:
configs.logger.error(f"Error processing Exa date: {e}")
return df
except Exception as e:
configs.logger.error(f"Error in Exa search for query {query}, attempt {6 - count}/5: {e}")
            await asyncio.sleep(1)  # non-blocking backoff; time.sleep() would stall the event loop
count -= 1
return pd.DataFrame(columns=required_columns)


########### Generate Queries ###########
class Queries(BaseModel):
queries: list[str] = Field(