-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
339 lines (278 loc) · 12.1 KB
/
main.py
File metadata and controls
339 lines (278 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
"""Main example demonstrating the PocketFlow Database Query Agent."""
import os
import sys
from typing import Dict, Any, Optional
# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()
# Add current directory to path for imports
sys.path.insert(0, '.')
from core.agent import DatabaseAgent, create_agent
from config import AgentConfig
def create_sample_database(db_path: str = "sample_business.db",
                           seed: Optional[int] = None) -> str:
    """Create a sample SQLite database for demonstration.

    Any existing file at ``db_path`` is deleted first. Four tables are then
    created and filled:

    * ``customers`` - 5 fixed rows;
    * ``products`` - 8 fixed rows;
    * ``orders`` - 20 random rows dated 1-90 days before now;
    * ``order_items`` - 1-4 random line items per order.

    Each order's ``total_amount`` is the sum of ``quantity * unit_price``
    over its line items, rounded to cents.

    Args:
        db_path: Path of the SQLite file to (re)create. Defaults to
            ``"sample_business.db"`` in the current working directory.
        seed: Optional seed for the order generator. ``None`` (the default)
            produces different orders on every run. An int makes the
            generated customers, statuses, line items and totals
            reproducible. Order dates are still relative to "now".

    Returns:
        ``db_path``.
    """
    import sqlite3
    from datetime import datetime, timedelta
    import random

    # Use a private generator so that seeding it leaves the global `random`
    # state untouched.
    rng = random.Random(seed)

    # Remove existing database
    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    # try/finally releases the connection (and its file handle) even if one
    # of the statements below fails.
    try:
        cursor = conn.cursor()

        # Create tables
        cursor.execute("""
            CREATE TABLE customers (
                customer_id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                email TEXT UNIQUE NOT NULL,
                city TEXT,
                country TEXT DEFAULT 'USA',
                registration_date DATE NOT NULL,
                is_active BOOLEAN DEFAULT 1
            )
        """)
        cursor.execute("""
            CREATE TABLE products (
                product_id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                category TEXT NOT NULL,
                price DECIMAL(10,2) NOT NULL,
                stock_quantity INTEGER DEFAULT 0,
                description TEXT,
                created_date DATE DEFAULT CURRENT_DATE
            )
        """)
        cursor.execute("""
            CREATE TABLE orders (
                order_id INTEGER PRIMARY KEY AUTOINCREMENT,
                customer_id INTEGER NOT NULL,
                order_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                status TEXT CHECK(status IN ('pending', 'processing', 'shipped', 'delivered', 'cancelled')),
                total_amount DECIMAL(10,2),
                shipping_address TEXT,
                FOREIGN KEY (customer_id) REFERENCES customers (customer_id)
            )
        """)
        cursor.execute("""
            CREATE TABLE order_items (
                order_item_id INTEGER PRIMARY KEY AUTOINCREMENT,
                order_id INTEGER NOT NULL,
                product_id INTEGER NOT NULL,
                quantity INTEGER NOT NULL CHECK (quantity > 0),
                unit_price DECIMAL(10,2) NOT NULL,
                FOREIGN KEY (order_id) REFERENCES orders (order_id),
                FOREIGN KEY (product_id) REFERENCES products (product_id)
            )
        """)

        # Insert sample data
        customers_data = [
            ('Alice Johnson', 'alice@email.com', 'New York', 'USA', '2023-01-15', 1),
            ('Bob Smith', 'bob@email.com', 'Los Angeles', 'USA', '2023-02-20', 1),
            ('Charlie Brown', 'charlie@email.com', 'Chicago', 'USA', '2023-03-10', 1),
            ('Diana Ross', 'diana@email.com', 'Houston', 'USA', '2023-04-05', 1),
            ('Eve Wilson', 'eve@email.com', 'Phoenix', 'USA', '2023-05-12', 0),
        ]
        cursor.executemany(
            "INSERT INTO customers (name, email, city, country, registration_date, is_active) VALUES (?, ?, ?, ?, ?, ?)",
            customers_data
        )

        products_data = [
            ('Laptop Pro', 'Electronics', 1299.99, 25, 'High-performance laptop'),
            ('Wireless Mouse', 'Electronics', 29.99, 150, 'Ergonomic wireless mouse'),
            ('Office Chair', 'Furniture', 249.99, 45, 'Comfortable office chair'),
            ('Coffee Maker', 'Appliances', 89.99, 30, 'Automatic drip coffee maker'),
            ('Desk Lamp', 'Furniture', 39.99, 60, 'LED desk lamp with adjustable arm'),
            ('Smartphone', 'Electronics', 799.99, 40, 'Latest model smartphone'),
            ('Water Bottle', 'Accessories', 19.99, 200, 'Stainless steel water bottle'),
            ('Notebook Set', 'Office', 15.99, 100, 'Set of 3 lined notebooks'),
        ]
        cursor.executemany(
            "INSERT INTO products (name, category, price, stock_quantity, description) VALUES (?, ?, ?, ?, ?)",
            products_data
        )

        # Read the generated keys back instead of hard-coding 1..5 / 1..8, so
        # the order generator stays in sync with the seed lists above. The
        # price map also replaces a SELECT per line item.
        customer_ids = [row[0] for row in cursor.execute(
            "SELECT customer_id FROM customers ORDER BY customer_id")]
        prices = dict(cursor.execute(
            "SELECT product_id, price FROM products ORDER BY product_id"))
        product_ids = list(prices)

        # Generate orders
        order_statuses = ['pending', 'processing', 'shipped', 'delivered', 'cancelled']
        for _ in range(20):
            customer_id = rng.choice(customer_ids)
            status = rng.choice(order_statuses)
            order_date = datetime.now() - timedelta(days=rng.randint(1, 90))
            cursor.execute(
                "INSERT INTO orders (customer_id, order_date, status, shipping_address) VALUES (?, ?, ?, ?)",
                (customer_id, order_date.isoformat(), status, f"{rng.randint(100, 999)} Main St")
            )
            order_id = cursor.lastrowid

            # Add order items
            total_amount = 0.0
            for _ in range(rng.randint(1, 4)):
                product_id = rng.choice(product_ids)
                quantity = rng.randint(1, 3)
                unit_price = prices[product_id]
                cursor.execute(
                    "INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES (?, ?, ?, ?)",
                    (order_id, product_id, quantity, unit_price)
                )
                total_amount += quantity * unit_price

            # Round to cents. Summing binary floats otherwise stores values
            # like 2629.9700000000003 in a DECIMAL(10,2) column.
            cursor.execute(
                "UPDATE orders SET total_amount = ? WHERE order_id = ?",
                (round(total_amount, 2), order_id)
            )

        conn.commit()
    finally:
        conn.close()

    print(f"Sample database created: {db_path}")
    return db_path
def run_examples(agent: "DatabaseAgent", examples: Optional[list] = None) -> None:
    """Run example queries to demonstrate the agent.

    Each question is sent to ``agent.query(question, options={'format':
    'table', 'include_metadata': True})``. The outcome is printed: the
    generated SQL, either ``formatted_data`` or a preview of up to 3 raw
    rows, the execution time, and any validation warnings, or else the
    failure message.

    Args:
        agent: The initialised agent. Only its ``query`` method is used.
        examples: Questions to run, in order. ``None`` (the default) runs
            the built-in demo questions written for the sample business
            schema.

    A failed result or an exception raised by the agent is reported, and the
    loop continues with the next question.
    """
    if examples is None:
        examples = [
            "How many customers do we have?",
            "Show me all products in the Electronics category",
            "What are the top 5 customers by total order amount?",
            "List all pending orders with customer details",
            "What's the average order value for each city?",
            "Show me products that are low on stock (less than 50 units)",
            "How many orders were placed in the last 30 days?",
            "What are our most popular product categories by sales volume?",
        ]

    print("\n" + "="*60)
    print("RUNNING EXAMPLE QUERIES")
    print("="*60)

    for i, query in enumerate(examples, 1):
        print(f"\n--- Example {i}: {query} ---")
        try:
            result = agent.query(query, options={
                'format': 'table',
                'include_metadata': True
            })
            if result.success:
                print(f"✅ Success! Generated SQL: {result.sql}")
                print(f"📊 Results ({result.row_count} rows):")
                if result.formatted_data:
                    print(result.formatted_data)
                else:
                    # Print first few rows
                    for j, row in enumerate(result.data[:3] if result.data else []):
                        print(f" Row {j+1}: {row}")
                    if result.data and len(result.data) > 3:
                        print(f" ... and {len(result.data) - 3} more rows")
                print(f"⏱️ Execution time: {result.execution_time:.3f}s")
                if result.validation_warnings:
                    print(f"⚠️ Warnings: {', '.join(result.validation_warnings)}")
            else:
                print(f"❌ Failed: {result.error}")
        except Exception as e:
            # Demo boundary: report the error and move on to the next example.
            print(f"💥 Error: {e}")
        print("-" * 50)
def _print_interactive_help() -> None:
    """Print the command list shared by the banner and the ``help`` command."""
    print("Available commands:")
    print(" - help: Show this help message")
    print(" - schema: Show database schema summary")
    print(" - stats: Show agent statistics")
    print(" - quit/exit: Exit interactive mode")


def interactive_mode(agent: "DatabaseAgent") -> None:
    """Run interactive query mode.

    Reads lines from stdin until ``quit``/``exit``, Ctrl-C or end of input.
    Matching is case-insensitive:

    * ``help``, ``schema`` and ``stats`` are handled locally.
    * Blank lines are ignored.
    * Anything else goes to
      ``agent.query(text, options={'format': 'table'}, session_id=...)``.
      All queries share the one session obtained from
      ``agent.create_session()``.

    Args:
        agent: The initialised agent. This function uses its
            ``create_session``, ``query``, ``get_schema_summary`` and
            ``get_statistics`` methods.
    """
    print("\n" + "="*60)
    print("INTERACTIVE MODE")
    print("="*60)
    print("Enter natural language queries (type 'quit' to exit, 'help' for commands)")
    _print_interactive_help()
    print()

    session_id = agent.create_session()

    while True:
        try:
            query = input("🔍 Query: ").strip()
            command = query.lower()

            if command in ('quit', 'exit'):
                break
            elif command == 'help':
                _print_interactive_help()
                continue
            elif command == 'schema':
                schema = agent.get_schema_summary()
                print("📋 Schema Summary:")
                print(f" Tables: {schema['table_count']}")
                print(f" Total Columns: {schema['total_columns']}")
                for table, info in schema['tables'].items():
                    print(f" - {table}: {info['columns']} columns, ~{info.get('row_count', 'unknown')} rows")
                continue
            elif command == 'stats':
                stats = agent.get_statistics()
                print("📊 Agent Statistics:")
                for key, value in stats.items():
                    print(f" {key}: {value}")
                continue
            elif not query:
                continue

            # Process the query
            result = agent.query(query, options={'format': 'table'}, session_id=session_id)
            if result.success:
                print(f"✅ SQL: {result.sql}")
                if result.formatted_data:
                    print(result.formatted_data)
                print(f"⏱️ Time: {result.execution_time:.3f}s | Rows: {result.row_count}")
            else:
                print(f"❌ Error: {result.error}")

        except (KeyboardInterrupt, EOFError):
            # EOFError covers Ctrl-D and exhausted piped stdin. If it fell
            # through to the generic handler below, every later input() would
            # raise it again and the loop would never end.
            print("\n👋 Goodbye!")
            break
        except Exception as e:
            print(f"💥 Unexpected error: {e}")
def _ask_yes_no(prompt: str) -> bool:
    """Prompt the user and return True if the answer starts with 'y' or 'Y'.

    End of input (Ctrl-D, or a closed or exhausted stdin when the demo is
    piped or run unattended) counts as "no" instead of raising EOFError.
    """
    try:
        answer = input(prompt)
    except EOFError:
        return False
    return answer.lower().startswith('y')


def main() -> int:
    """Main function demonstrating the Database Query Agent.

    Steps:

    1. Switch ``LLM_PROVIDER`` to ``mock`` when ``OPENAI_API_KEY`` is unset.
    2. Recreate the sample SQLite database.
    3. Build a ``DatabaseAgent`` for it, restricted to ``SELECT``.
    4. If the first CLI argument is ``--interactive``, enter interactive
       mode. Otherwise run the canned examples, then offer interactive mode.

    The agent is always closed, and the user is asked whether to delete the
    sample database file.

    Returns:
        The process exit status: 0 on success; 1 if the agent failed
        validation or any exception escaped the demo.
    """
    print("🚀 PocketFlow Database Query Agent Demo")
    print("=" * 50)

    # Set up environment
    if not os.getenv('OPENAI_API_KEY'):
        print("⚠️ Warning: OPENAI_API_KEY not set. Using mock client for demo.")
        os.environ['LLM_PROVIDER'] = 'mock'

    # Create sample database
    db_path = create_sample_database()
    db_url = f"sqlite:///{db_path}"

    # Create agent with configuration
    print("\n📊 Initializing Database Agent...")
    # Explicit sentinel (instead of `'agent' in locals()`) so the cleanup only
    # closes an agent that was actually constructed.
    agent = None
    exit_code = 0
    try:
        # Create agent configuration
        config = AgentConfig.from_env()
        config.database.url = db_url
        config.database.type = "sqlite"
        config.security.allowed_operations = ["SELECT"]  # Read-only for demo
        config.cache.enabled = True
        config.verbose = True

        # Initialize agent
        agent = DatabaseAgent(config=config)

        # Validate setup
        if agent.validate_connection():
            print("✅ Agent initialized successfully!")

            # Show schema summary
            schema = agent.get_schema_summary()
            print("\n📋 Database Schema:")
            print(f" Tables: {schema['table_count']}")
            print(f" Columns: {schema['total_columns']}")

            # Run examples
            if len(sys.argv) > 1 and sys.argv[1] == "--interactive":
                interactive_mode(agent)
            else:
                run_examples(agent)

                # Offer interactive mode
                if _ask_yes_no("\n🤔 Would you like to try interactive mode? (y/n): "):
                    interactive_mode(agent)
        else:
            print("❌ Failed to validate agent setup")
            # Previously fell through to `return 0`, reporting success.
            exit_code = 1

    except Exception as e:
        print(f"💥 Failed to initialize agent: {e}")
        return 1

    finally:
        # Cleanup
        if agent is not None:
            agent.close()

        # Optionally remove sample database
        if _ask_yes_no("\n🧹 Remove sample database? (y/n): ") and os.path.exists(db_path):
            os.remove(db_path)
            print(f"🗑️ Removed {db_path}")

    print("\n👋 Thank you for trying PocketFlow Database Query Agent!")
    return exit_code


if __name__ == "__main__":
    sys.exit(main())