forked from analystanand/rag_webapp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
191 lines (158 loc) · 7.81 KB
/
app.py
File metadata and controls
191 lines (158 loc) · 7.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import streamlit as st
import PyPDF2
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import numpy as np
class RAGComparison:
    """Compare zero-shot (no-context) GPT-2 generation against retrieval-augmented QA.

    Loads three models up front:
      * a SentenceTransformer ('all-MiniLM-L6-v2') for chunk/query embeddings,
      * GPT-2 (tokenizer + causal LM) for zero-shot generation,
      * a RoBERTa extractive-QA pipeline for the context-grounded answer.
    """

    def __init__(self):
        # Embedding model used for both chunk and query vectors.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Load tokenizer and model explicitly for zero-shot generation.
        self.zero_shot_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.zero_shot_model = AutoModelForCausalLM.from_pretrained("gpt2")
        # GPT-2 ships without a pad token; reuse EOS so padding/generation work.
        self.zero_shot_tokenizer.pad_token = self.zero_shot_tokenizer.eos_token
        # Extractive QA pipeline used for the RAG (with-context) answer.
        self.rag_model = pipeline(
            "question-answering",
            model="deepset/roberta-base-squad2",
            tokenizer="deepset/roberta-base-squad2",
        )

    def extract_text_from_pdf(self, uploaded_file):
        """Return the concatenated text of every page of an uploaded PDF.

        Pages with no extractable text (``extract_text()`` returning None)
        contribute an empty string instead of crashing the join.
        """
        pdf_reader = PyPDF2.PdfReader(uploaded_file)
        # join() avoids the quadratic string-+= pattern on large PDFs.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

    def chunk_text(self, text, chunk_size=200, overlap=50):
        """Split *text* into overlapping word chunks.

        Args:
            text: source text; empty or whitespace-only text yields [].
            chunk_size: maximum number of words per chunk.
            overlap: words shared between consecutive chunks.

        Returns:
            list[str]: the chunks, each ``" "``-joined from ``chunk_size`` words
            (the final chunk may be shorter).

        Raises:
            ValueError: if ``overlap >= chunk_size`` — the window could never
                advance (the original silently errored or returned nothing).
        """
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")
        words = text.split()
        step = chunk_size - overlap
        return [
            " ".join(words[i:i + chunk_size])
            for i in range(0, len(words), step)
        ]

    def create_embeddings(self, chunks):
        """Return one embedding vector per chunk (rows align with *chunks*)."""
        return self.embedding_model.encode(chunks)

    def find_most_relevant_chunk(self, query, chunks, embeddings):
        """Return the chunk most similar to *query* by cosine similarity.

        The previous version ranked by raw dot product; since
        ``SentenceTransformer.encode`` does not L2-normalize by default, that
        biased selection toward larger-norm embeddings. Normalizing both sides
        makes the score a true cosine similarity.
        """
        query_embedding = self.embedding_model.encode([query])[0]
        emb = np.asarray(embeddings)
        # Epsilon guards against division by zero for degenerate vectors.
        denom = np.linalg.norm(emb, axis=1) * np.linalg.norm(query_embedding) + 1e-12
        similarities = emb @ query_embedding / denom
        return chunks[int(np.argmax(similarities))]

    def generate_zero_shot_response(self, query):
        """Generate a GPT-2 continuation of *query* with no document context.

        Sampling is enabled, so the output is intentionally non-deterministic.
        """
        inputs = self.zero_shot_tokenizer(
            query,
            return_tensors="pt",
            truncation=True,
            max_length=50,
        )
        # Pass attention_mask explicitly: omitting it triggers a transformers
        # warning and can mis-handle padded inputs.
        outputs = self.zero_shot_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=150,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=self.zero_shot_tokenizer.eos_token_id,
        )
        return self.zero_shot_tokenizer.decode(
            outputs[0],
            skip_special_tokens=True,
        )

    def generate_rag_response(self, context, query):
        """Answer *query* extractively from *context* via the QA pipeline."""
        result = self.rag_model({
            'context': context,
            'question': query
        })
        return result['answer']
@st.cache_resource
def _load_rag_app():
    """Build the heavyweight model bundle once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, each keystroke reloaded three transformer models.
    """
    return RAGComparison()

def main():
    """Streamlit entry point: side-by-side zero-shot vs. RAG comparison."""
    st.title("Multi-Mode RAG Comparison")
    # Built-in demo document so the app is usable without any upload.
    sample_story = """
# Titanic Survival Prediction Model Report
## Data Preprocessing
The dataset underwent careful preprocessing, which involved handling missing values, encoding categorical variables, and normalizing numerical features to prepare the data for machine learning analysis.
## Model: XGBoost Classifier
### Features and Importance
The XGBoost model revealed a clear hierarchy of feature importance in predicting passenger survival. Fare emerged as the most critical feature, accounting for the highest importance score of 0.352. This was closely followed by Age, which contributed 0.276 to the model's predictive power. Gender (Sex) ranked third with an importance score of 0.215, demonstrating its significant impact on survival chances. Passenger class (Pclass) contributed 0.097 to the model, while the embarkation point (Embarked) played a minor role with 0.038. Familial connections through Sibling/Spouse (SibSp) and Parent/Child (Parch) relationships had the least influence, with scores of 0.022 and 0.010 respectively.
### Model Performance Metrics
The XGBoost classifier demonstrated robust performance across multiple evaluation metrics. The model achieved an overall accuracy of 85%, with precision and recall both hovering around 0.83. The F1 score of 0.83 indicates a balanced performance between precision and recall, while the AUC-ROC score of 0.89 suggests strong predictive capabilities in distinguishing between survival and non-survival scenarios.
## Key Insights
- Fare and passenger class emerged as the most critical predictors of survival, indicating the significant role of socioeconomic factors during the Titanic disaster.
- Gender played a crucial role in survival probability, with clear disparities in rescue rates.
- Age was found to be the second most important feature, suggesting that a passenger's age significantly influenced their chances of survival.
## Recommendations
- To enhance the model's predictive power, researchers should consider incorporating more detailed passenger information.
- Exploring ensemble methods could potentially improve model performance.
- Additionally, validating the model against additional historical datasets would provide further robustness to the analysis.
"""
    # Cached: models are loaded once, not on every Streamlit rerun.
    rag_app = _load_rag_app()

    # Input method selection.
    input_mode = st.sidebar.radio(
        "Choose Input Mode",
        ["Default Story", "Direct Text Input", "PDF Upload"]
    )

    # Resolve the working text from the chosen input mode.
    if input_mode == "Default Story":
        st.subheader("Sample Report")
        st.markdown(sample_story)
        text = sample_story
    elif input_mode == "Direct Text Input":
        text = st.text_area(
            "Enter your text",
            height=250,
            placeholder="Paste the text you want to query..."
        )
    else:  # PDF Upload
        uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
        if uploaded_file is not None:
            text = rag_app.extract_text_from_pdf(uploaded_file)
            st.success("PDF Uploaded and Processed!")
        else:
            text = ""

    # Proceed only once some text is available (text_area starts empty).
    if text:
        # Chunk and embed the text for retrieval.
        chunks = rag_app.chunk_text(text)
        embeddings = rag_app.create_embeddings(chunks)
        questions = st.text_input("Enter your question about the text")
        if questions:
            # Side-by-side comparison columns.
            col1, col2 = st.columns(2)
            with col1:
                st.subheader("Zero-Shot (Without RAG)")
                st.warning("Response without context")
                zero_shot_response = rag_app.generate_zero_shot_response(questions)
                st.write(zero_shot_response)
            with col2:
                st.subheader("RAG Response")
                st.success("Response with context")
                # Retrieve the best-matching chunk, then answer from it.
                relevant_chunk = rag_app.find_most_relevant_chunk(questions, chunks, embeddings)
                rag_response = rag_app.generate_rag_response(relevant_chunk, questions)
                st.write(rag_response)
                with st.expander("See Relevant Context"):
                    st.write(relevant_chunk)

    # Sidebar usage instructions.
    st.sidebar.markdown("### How to Use")
    st.sidebar.info(
        "1. Choose input mode\n"
        "2. Enter text or upload PDF\n"
        "3. Ask a question\n"
        "4. Compare Zero-Shot vs RAG responses"
    )
# Entry point when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()