Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ wheels/
.venv

# custom
chroma_db/
.chroma_db
.env
.ipynb_checkpoints
4 changes: 3 additions & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def main():
with st.chat_message("assistant"):
with st.spinner("Thinking..."):
try:
answer, sources = st.session_state.chatbot.ask_question(prompt)
answer, sources = st.session_state.chatbot.ask_question(
prompt, thread_id=st.session_state.thread_id
)
st.markdown(answer)

if sources:
Expand Down
53 changes: 46 additions & 7 deletions chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PyPDFLoader,
TextLoader,
)
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langchain_huggingface import HuggingFaceEmbeddings
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(self):
self.vector_store = Chroma(
collection_name="docs",
embedding_function=embeddings,
persist_directory="./chroma_db",
persist_directory="./.chroma_db",
)

print("Init ToolNode...")
Expand Down Expand Up @@ -112,6 +113,7 @@ def retrieve(query: str):
(f"Source: {doc.metadata}\n" f"Content: {doc.page_content}")
for doc in retrieved_docs
)
retrieved_docs = [doc.__dict__ for doc in retrieved_docs]
return serialized, retrieved_docs

return retrieve
Expand All @@ -127,7 +129,7 @@ def query_or_respond(self, state: MessagesState) -> Dict:
_type_: _description_
"""

print(f"State messages in query_or_respond: {state['messages']}")
print(f"State messages in query_or_respond: {state['messages']}\n")

messages = state["messages"]
if not any(msg.type == "system" for msg in messages):
Expand Down Expand Up @@ -235,11 +237,48 @@ def ask_question(self, question: str, thread_id: str = None) -> Tuple[str, List]
print(f"Response from graph: {response}\n")
print(f"Latest response: {latest_response}\n")

retrieved_docs = []
for msg in response["messages"]:
if msg.type == "tool" and hasattr(msg, "artifact") and msg.artifact:
retrieved_docs.extend(msg.artifact)
latest_tool_msg = next(
(
msg
for msg in reversed(response["messages"])
if msg.type == "tool" and hasattr(msg, "artifact") and msg.artifact
),
None,
)

print(f"Total retrieved documents: {len(retrieved_docs)}")
retrieved_docs = []
if latest_tool_msg:
for artifact in latest_tool_msg.artifact:
if isinstance(artifact, dict):
doc = Document(
id=artifact["id"],
page_content=artifact["page_content"],
metadata=artifact["metadata"],
page_content_type=artifact["page_content"],
)
else:
doc = artifact
retrieved_docs.append(doc)

# retrieved_docs = []
# for msg in response["messages"]:
# if msg.type == "tool" and hasattr(msg, "artifact") and msg.artifact:
# for artifact in msg.artifact:
# if isinstance(artifact, dict):
# # Convert dict to Document if necessary
# doc = Document(
# id=artifact["id"],
# page_content=artifact["page_content"],
# metadata=artifact["metadata"],
# page_content_type=artifact["page_content"],
# )
# else:
# doc = artifact

# retrieved_docs.append(doc)

print(f"Total retrieved documents: {len(retrieved_docs)}\n")
print(f"Retrieved documents: {retrieved_docs}\n")
print("=" * 50)

return latest_response.content, retrieved_docs
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"torch==2.7.1",
"langchain-community==0.3.1",
"langchain-community==0.3.27",
"langchain-anthropic",
"langchain-chroma",
"langchain-huggingface",
Expand Down
3 changes: 2 additions & 1 deletion run.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
source .venv/bin/activate
streamlit run app.py
rm -r .chroma_db
streamlit run app.py
Loading