-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb.py
More file actions
163 lines (120 loc) · 5.07 KB
/
db.py
File metadata and controls
163 lines (120 loc) · 5.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# db.py
import json, os, requests
from datetime import datetime
from sqlmodel import SQLModel, Field, create_engine, Session, select, update, Column, JSON
from typing import List, Dict, Any, Literal
from pydantic import BaseModel
# Initialize SQLite database
DB_PATH = "databases/ai_service.db"
VDB_MODEL = "mxbai-embed-large"
engine = create_engine(f"sqlite:///{DB_PATH}")
def check_host(host="172.24.192.1", port=11434, timeout=3):
    """Return True when an HTTP server answers 200 OK at http://host:port/."""
    url = f"http://{host}:{port}/"
    try:
        reply = requests.get(url, timeout=timeout)
    except Exception:
        # Any failure (refused, timeout, DNS) just means the host is unusable.
        return False
    return reply.status_code == 200
# Determine the working host
# Prefer the gateway address (typical WSL2 host IP) and fall back to
# localhost when nothing answers there.
# NOTE(review): assumes an Ollama server listens on port 11434 — confirm.
VDB_HOST = "172.24.192.1" if check_host("172.24.192.1", 11434) else "127.0.0.1"
print("VDB_HOST", VDB_HOST)
class LogEntry(SQLModel, table=True):
    """One request/response pair persisted to the `logs` table."""

    __tablename__: str = "logs"
    # Auto-increment primary key; None until the row is flushed.
    id: int | None = Field(default=None, primary_key=True)
    # Naive local time of insertion.
    # NOTE(review): naive local time, not UTC — confirm downstream consumers.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Raw request payload (serialized text).
    request: str
    # Raw response payload (serialized text).
    response: str
    # Arbitrary JSON metadata stored in a JSON column.
    meta: dict = Field(sa_column=Column(JSON))
    # Free-form category of the log entry; may be NULL.
    type: str | None
    # Whether the request went through the RAG pipeline.
    is_rag: bool = Field(default=False)
class Message(BaseModel):
    """A single chat message in OpenAI wire format."""

    role: Literal["system", "user", "assistant"]
    content: str

    def to_dict(self) -> Dict[str, str]:
        """Serialize to the plain-dict shape the OpenAI chat API expects."""
        return dict(role=self.role, content=self.content)
# Closed set of model-selection profiles shared by the request schemas below.
_profiles = Literal["external_fastest", "external_smartest", "private_fast", "private_balanced", "private_smart"]
class QueryRequest(BaseModel):
    """Schema for query requests."""

    # Free-form prompt text to send to the model.
    prompt: str
    # Model-selection profile; defaults to the fast private model.
    profile: _profiles = "private_fast"
class ChatMessage(BaseModel):
    """A chat message used by ChatRequest.

    NOTE(review): duplicates the fields of Message above (minus to_dict);
    consider consolidating the two schemas.
    """

    role: Literal["system", "user", "assistant"]
    content: str
class ChatRequest(BaseModel):
    """Schema for multi-turn chat requests."""

    # Full conversation history, oldest first.
    messages: List[ChatMessage]
    # Model-selection profile; defaults to the fastest external model.
    profile: _profiles | None = "external_fastest"
# Initialize ChromaDB
class VectorDB:
    """Thin wrapper around a persistent ChromaDB collection whose
    embeddings are produced by an Ollama server at VDB_HOST:11434."""

    def __init__(self, collection_name: str = "flow_dataset"):
        # Imported lazily so importing this module stays cheap for code
        # paths that never touch the vector store.
        import chromadb

        self.chroma_client = chromadb.PersistentClient(path="./databases/chromadb")
        self.collection = self.chroma_client.get_or_create_collection(
            collection_name,
            metadata={"hnsw:space": "cosine", "hnsw:search_ef": 400},
        )

    def embed(self, text=None, model=VDB_MODEL):
        """Return the embedding vector for *text* from the Ollama API.

        Raises:
            RuntimeError: when the server answers with a non-200 status.
                (RuntimeError subclasses Exception, so existing callers
                catching Exception are unaffected.)
        """
        response = requests.post(
            f"http://{VDB_HOST}:11434/api/embeddings",
            json={"model": model, "prompt": text},
            headers={"Content-Type": "application/json"},
        )
        if response.status_code != 200:
            raise RuntimeError(f"Embed Request failed with status code {response.status_code}: {response.text}")
        return response.json()["embedding"]

    def add(self, id: str, text: str, metadata: Dict[str, Any] | None = None):
        """Add a document to the collection, embedding it first."""
        # Truthiness test (not `is not None`) kept on purpose: ChromaDB
        # rejects empty metadata dicts, so {} is treated as "no metadata".
        self.collection.add(ids=[id], documents=[text], embeddings=[self.embed(text)], metadatas=[metadata] if metadata else None)

    def query(self, text: str, n=3):
        """Return the *n* documents most similar to *text*."""
        return self.collection.query(query_embeddings=[self.embed(text)], n_results=n)

    def empty(self):
        """Delete every document from the collection.

        peek() returns at most the requested number of ids, so drain in a
        loop — the previous version silently removed only the first 1000
        documents.
        """
        while True:
            batch = self.collection.peek(1000)
            if not batch["ids"]:
                break
            self.collection.delete(ids=batch["ids"])
def init_db():
    """Ensure the databases directory exists and create any missing tables.

    `create_all` is idempotent (checkfirst=True by default), so no file
    existence check is needed. The previous `os.path.exists(DB_PATH)`
    guard actually skipped table creation whenever an empty ai_service.db
    file was already present on disk.
    """
    os.makedirs("./databases", exist_ok=True)
    SQLModel.metadata.create_all(engine)
def safe_migrate_db():
    """Safely migrate the database by adding any missing columns.

    Compares the live `logs` table against the LogEntry model and issues
    an ALTER TABLE ADD COLUMN for every column the model declares but the
    table lacks. Existing columns are never altered or dropped.
    """
    # Local import keeps sqlalchemy internals out of the module namespace.
    from sqlalchemy import inspect
    inspector = inspect(engine)
    # Fresh database: create the full schema and stop.
    if not inspector.has_table("logs"):
        SQLModel.metadata.create_all(engine)
        return
    existing_columns = {column["name"] for column in inspector.get_columns("logs")}
    required_columns = {column.name for column in LogEntry.__table__.columns}
    missing_columns = required_columns - existing_columns
    if missing_columns:
        # engine.begin() commits the whole batch of ALTERs atomically.
        with engine.begin() as conn:
            for column in missing_columns:
                col_type = LogEntry.__table__.columns[column].type
                nullable = LogEntry.__table__.columns[column].nullable
                default = LogEntry.__table__.columns[column].default
                # DDL is built from model metadata (not user input), so
                # string interpolation is acceptable here.
                alter_stmt = f"ALTER TABLE logs ADD COLUMN {column} {col_type}"
                if not nullable:
                    # SQLite requires a DEFAULT when adding a NOT NULL
                    # column to a table that may already contain rows.
                    if default is not None:
                        alter_stmt += f" DEFAULT {default.arg if hasattr(default, 'arg') else default}"
                    alter_stmt += " NOT NULL"
                elif default is not None:
                    alter_stmt += f" DEFAULT {default.arg if hasattr(default, 'arg') else default}"
                print(f"Adding missing column: {column}")
                conn.exec_driver_sql(alter_stmt)
def get_session():
    """Open and return a new ORM session bound to the shared engine."""
    session = Session(engine)
    return session
# Module import side effects: make sure the SQLite schema exists and is
# up to date, then open the shared vector-store handle.
init_db()
safe_migrate_db()
# Shared VectorDB instance used by the rest of the service.
vdb = VectorDB()
if __name__ == "__main__":
    # Example usage: print the first two log rows.
    with Session(engine) as session:
        logs = session.exec(select(LogEntry).limit(2)).all()
        for log in logs:
            print(log)
        # No commit: this block only reads; the previous session.commit()
        # was a misleading no-op.