Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions backend/agents/reactive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
from api_endpoints.financeGPT.chatbot_endpoints import (
get_relevant_chunks, add_message_to_db, add_sources_to_db,
retrieve_message_from_db, retrieve_docs_from_db
retrieve_message_from_db, retrieve_docs_from_db, serialize_sources_for_api
)
from .config import AgentConfig
from .multi_agent_system import MultiAgentDocumentSystem
Expand Down Expand Up @@ -651,12 +651,13 @@ def stream_callback(event):

# Try to extract sources from the agent's reasoning early
sources = self._extract_sources_from_response(response, chat_id, user_email, query)
sources_payload = serialize_sources_for_api(sources)

# Yield the completion event that frontend expects
yield {
"type": "complete",
"answer": answer,
"sources": sources if sources else [],
"sources": sources_payload,
"thought": final_thought,
"timestamp": self._get_timestamp()
}
Expand Down Expand Up @@ -730,7 +731,7 @@ def stream_callback(event):
yield {
"type": "step-complete",
"answer": answer,
"sources": sources if sources else [],
"sources": sources_payload,
"thought": "Processing complete - response ready for user",
"timestamp": self._get_timestamp()
}
Expand Down Expand Up @@ -764,7 +765,7 @@ def stream_callback(event):
"type": "response-complete",
"answer": answer,
"message_id": message_id,
"sources": sources if sources else [],
"sources": sources_payload,
"message": "Response generated and saved successfully",
"total_steps": len(final_reasoning_steps),
"agent_reasoning": response.get("intermediate_steps", []),
Expand Down Expand Up @@ -1133,7 +1134,7 @@ def process_query(self, query: str, chat_id: int, user_email: str) -> Dict[str,
def _extract_sources_from_response(self, response: Dict, chat_id: int, user_email: str, query: str) -> List[tuple]:
try:
# Try to get sources from the document retrieval that was likely used
sources = get_relevant_chunks(2, query, chat_id, user_email)
sources = get_relevant_chunks(2, query, chat_id, user_email, include_metadata=True)
if sources and sources != ["No text found"]:
return sources
except Exception as e:
Expand Down Expand Up @@ -1179,4 +1180,4 @@ def process_workflow_query(self, query: str, workflow_id: int, user_email: str)
return completion.completion

except Exception as e:
return f"Error processing workflow query: {str(e)}"
return f"Error processing workflow query: {str(e)}"
49 changes: 49 additions & 0 deletions backend/api_endpoints/financeGPT/chatbot_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,55 @@
prepare_chunks_for_embedding = finance_gpt_service.prepare_chunks_for_embedding


def serialize_sources_for_api(sources):
    """Normalize retrieval sources into JSON-ready dicts for API responses.

    Accepts either dict-shaped sources (already carrying chunk metadata) or
    legacy tuples of the form ``(chunk_text, document_name, page_number,
    start_index, end_index)``. Each entry gets a positional ``source-<i>`` id.
    Tuples with fewer than two elements are silently dropped.
    """
    normalized = []
    for position, raw in enumerate(sources or []):
        if isinstance(raw, dict):
            normalized.append(
                {
                    "id": f"source-{position}",
                    "document_name": raw.get("document_name", "Unknown document"),
                    "chunk_text": raw.get("chunk_text", ""),
                    "page_number": raw.get("page_number"),
                    "start_index": raw.get("start_index"),
                    "end_index": raw.get("end_index"),
                    "source_type": raw.get("source_type", "document_chunk"),
                }
            )
        elif len(raw) >= 2:
            # Legacy tuple layout: text first, then name, then optional positions.
            page = raw[2] if len(raw) > 2 and isinstance(raw[2], int) else None
            normalized.append(
                {
                    "id": f"source-{position}",
                    "document_name": raw[1],
                    "chunk_text": raw[0],
                    "page_number": page,
                    "start_index": raw[3] if len(raw) > 3 else None,
                    "end_index": raw[4] if len(raw) > 4 else None,
                    "source_type": "document_chunk",
                }
            )
    return normalized


def sources_to_prompt_context(sources):
    """Render retrieval sources as one space-joined string for LLM prompts.

    Each source becomes ``Document: <name>[ (page N)]: <chunk>``; the page
    suffix appears only when a page number is present. Sources are first
    normalized via serialize_sources_for_api, so both dict and legacy tuple
    inputs are accepted.
    """
    rendered = []
    for entry in serialize_sources_for_api(sources):
        page = entry.get("page_number")
        suffix = f" (page {page})" if page is not None else ""
        rendered.append(
            f"Document: {entry['document_name']}{suffix}: {entry['chunk_text']}"
        )
    return " ".join(rendered)


def access_sharable_chat(share_uuid, user_id=1):
new_chat_id = access_shareable_chat(share_uuid, user_id)
if new_chat_id is None:
Expand Down
31 changes: 20 additions & 11 deletions backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@
get_text_from_url,
)
from tika import parser as p
from api_endpoints.financeGPT.chatbot_endpoints import (
serialize_sources_for_api,
sources_to_prompt_context,
)

_get_model()
from datetime import datetime
Expand Down Expand Up @@ -864,8 +868,8 @@ def _process_message_pdf_fallback(message, chat_id, model_type, model_key, user_
})

#Get most relevant section from the document
sources = get_relevant_chunks(2, query, chat_id, user_email)
sources_str = " ".join([", ".join(str(elem) for elem in source) for source in sources])
sources = get_relevant_chunks(2, query, chat_id, user_email, include_metadata=True)
sources_str = sources_to_prompt_context(sources)

if (model_type == 0):
if model_key:
Expand Down Expand Up @@ -921,7 +925,11 @@ def _process_message_pdf_fallback(message, chat_id, model_type, model_key, user_
except:
print("no sources")

return jsonify(answer=answer)
return jsonify(
answer=answer,
message_id=message_id,
sources=serialize_sources_for_api(sources),
)


@app.route('/add-model-key', methods=['POST'])
Expand Down Expand Up @@ -1082,13 +1090,12 @@ def public_ingest_pdf(): # pragma: no cover
agent = ReactiveDocumentAgent(model_type=model_type, model_key=model_key)
result = agent.process_query(message.strip(), chat_id, user_email)

# Format sources for compatibility
sources_swapped = [[str(elem) for elem in source[::-1]] for source in result.get("sources", [])]
sources_payload = serialize_sources_for_api(result.get("sources", []))

return jsonify(
message_id=result.get("message_id"),
answer=result["answer"],
sources=sources_swapped
sources=sources_payload
)

except Exception as e:
Expand All @@ -1110,10 +1117,8 @@ def _public_chat_fallback(message, chat_id, model_type, model_key, user_email):
add_message_to_db(query, chat_id, 1)

#Get most relevant section from the document
sources = get_relevant_chunks(2, query, chat_id, user_email)
sources_str = " ".join([", ".join(str(elem) for elem in source) for source in sources])

sources_swapped = [[str(elem) for elem in source[::-1]] for source in sources]
sources = get_relevant_chunks(2, query, chat_id, user_email, include_metadata=True)
sources_str = sources_to_prompt_context(sources)

if (model_type == 0):
if model_key:
Expand Down Expand Up @@ -1164,7 +1169,11 @@ def _public_chat_fallback(message, chat_id, model_type, model_key, user_email):
except:
print("no sources")

return jsonify(message_id=message_id, answer=answer, sources=sources_swapped)
return jsonify(
message_id=message_id,
answer=answer,
sources=serialize_sources_for_api(sources),
)

@app.route('/public/evaluate', methods = ['POST'])
@valid_api_key_required
Expand Down
12 changes: 9 additions & 3 deletions backend/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,9 +1338,15 @@ def get_document_content(document_id, email):
def _combine_sources(sources):
combined_sources = ""
for source in sources:
if len(source) < 2:
continue
chunk_text, document_name = source[0], source[1]
if isinstance(source, dict):
chunk_text = source.get("chunk_text")
document_name = source.get("document_name")
if not chunk_text or not document_name:
continue
else:
if len(source) < 2:
continue
chunk_text, document_name = source[0], source[1]
combined_sources += f"Document: {document_name}: {chunk_text}\n\n"
return combined_sources

Expand Down
21 changes: 15 additions & 6 deletions backend/services/finance_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def knn(query_vector, document_vectors):
]


def get_relevant_chunks(k, question, chat_id, user_email):
def get_relevant_chunks(k, question, chat_id, user_email, include_metadata=False):
rows = get_chat_chunks(user_email, chat_id)
chunk_embeddings = []
chunk_metadata = []
Expand All @@ -432,6 +432,7 @@ def get_relevant_chunks(k, question, chat_id, user_email):
{
"start": row["start_index"],
"end": row["end_index"],
"page_number": row.get("page_number"),
"document_name": row["document_name"],
"document_text": row["document_text"],
}
Expand All @@ -452,12 +453,20 @@ def get_relevant_chunks(k, question, chat_id, user_email):
source_chunks = []
for index in range(min(k, len(results))):
metadata = chunk_metadata[results[index]["index"]]
source_chunks.append(
(
metadata["document_text"][metadata["start"] : metadata["end"]],
metadata["document_name"],
chunk_text = metadata["document_text"][metadata["start"] : metadata["end"]]
if include_metadata:
source_chunks.append(
{
"chunk_text": chunk_text,
"document_name": metadata["document_name"],
"page_number": metadata["page_number"],
"start_index": metadata["start"],
"end_index": metadata["end"],
"source_type": "document_chunk",
}
)
)
else:
source_chunks.append((chunk_text, metadata["document_name"]))
return source_chunks


Expand Down
13 changes: 13 additions & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def remote(self, *args: Any, **kwargs: Any) -> None:
finance_module.chunk_document = _RemoteCallable()
finance_module.add_document_to_db = lambda *args, **kwargs: (1, False)
finance_module.get_relevant_chunks = lambda *args, **kwargs: [("chunk", "doc", 1)]
finance_module.serialize_sources_for_api = lambda sources: [
{
"id": source.get("id", f"source-{index}") if isinstance(source, dict) else f"source-{index}",
"document_name": source.get("document_name", "Unknown document") if isinstance(source, dict) else source[1],
"chunk_text": source.get("chunk_text", "") if isinstance(source, dict) else source[0],
"page_number": source.get("page_number") if isinstance(source, dict) else (source[2] if len(source) > 2 else None),
"start_index": source.get("start_index") if isinstance(source, dict) else (source[3] if len(source) > 3 else None),
"end_index": source.get("end_index") if isinstance(source, dict) else (source[4] if len(source) > 4 else None),
"source_type": source.get("source_type", "document_chunk") if isinstance(source, dict) else "document_chunk",
}
for index, source in enumerate(sources or [])
]
finance_module.sources_to_prompt_context = lambda sources: " ".join(str(source) for source in (sources or []))
finance_module.create_chat_shareable_url = lambda chat_id: f"/playbook/{chat_id}"
finance_module.access_sharable_chat = lambda playbook_url: {"url": playbook_url}
finance_module._get_model = lambda: None
Expand Down
13 changes: 13 additions & 0 deletions backend/tests/test_finance_gpt_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,15 @@ def test_get_relevant_chunks(monkeypatch: pytest.MonkeyPatch) -> None:
{
"start_index": 0,
"end_index": 4,
"page_number": 1,
"embedding_vector": vector,
"document_name": "doc-a",
"document_text": "abcdefgh",
},
{
"start_index": 4,
"end_index": 8,
"page_number": 2,
"embedding_vector": vector,
"document_name": "doc-b",
"document_text": "ijklmnop",
Expand All @@ -118,6 +120,17 @@ def test_get_relevant_chunks(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(sources) == 1
assert sources[0][1] in {"doc-a", "doc-b"}

structured_sources = finance_gpt.get_relevant_chunks(
1,
"question",
9,
"user@example.com",
include_metadata=True,
)
assert len(structured_sources) == 1
assert structured_sources[0]["document_name"] in {"doc-a", "doc-b"}
assert structured_sources[0]["source_type"] == "document_chunk"


def test_get_relevant_chunks_handles_embedding_failure(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(finance_gpt, "get_chat_chunks", lambda user_email, chat_id: [])
Expand Down
31 changes: 29 additions & 2 deletions backend/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,20 @@ def test_public_chat_and_evaluate(client: Any, app_module: Any, monkeypatch: pyt
monkeypatch.setattr(app_module, "ensure_SDK_user_exists", lambda user_email: 1)
monkeypatch.setattr(app_module.AgentConfig, "is_agent_enabled", staticmethod(lambda: False))
monkeypatch.setattr(app_module, "get_chat_info", lambda chat_id: (0, 0, "Chat"))
monkeypatch.setattr(app_module, "get_relevant_chunks", lambda *args: [("chunk", "doc", 1)])
monkeypatch.setattr(
app_module,
"get_relevant_chunks",
lambda *args, **kwargs: [
{
"chunk_text": "chunk",
"document_name": "doc",
"page_number": 1,
"start_index": 0,
"end_index": 5,
"source_type": "document_chunk",
}
],
)
monkeypatch.setattr(app_module, "add_message_to_db", lambda *args, **kwargs: 10)
monkeypatch.setattr(app_module, "add_sources_to_db", lambda *args, **kwargs: None)
monkeypatch.setattr(
Expand All @@ -544,7 +557,21 @@ def test_public_chat_and_evaluate(client: Any, app_module: Any, monkeypatch: pyt
json={"chat_id": 9, "message": "hello"},
)
assert chat_response.status_code == 200
assert chat_response.get_json() == {"message_id": 10, "answer": "openai answer", "sources": [["1", "doc", "chunk"]]}
assert chat_response.get_json() == {
"message_id": 10,
"answer": "openai answer",
"sources": [
{
"id": "source-0",
"document_name": "doc",
"chunk_text": "chunk",
"page_number": 1,
"start_index": 0,
"end_index": 5,
"source_type": "document_chunk",
}
],
}

evaluate_response = client.post(
"/public/evaluate",
Expand Down
Loading
Loading