-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdbconnect.py
More file actions
123 lines (100 loc) · 4.25 KB
/
dbconnect.py
File metadata and controls
123 lines (100 loc) · 4.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from urllib.parse import quote

from sqlalchemy import create_engine, text, select, func
from sqlalchemy.orm import Session, selectinload

from dbmodel import Base, Document, DocumentChunk
# Connection settings are mandatory: a missing variable raises KeyError at
# import time instead of surfacing later as an obscure connection failure.
POSTGRES_USER = os.environ["POSTGRES_USER"]
POSTGRES_PASSWORD = os.environ["POSTGRES_PASSWORD"]
POSTGRES_HOST = os.environ["POSTGRES_HOST"]
POSTGRES_PORT = os.environ["POSTGRES_PORT"]
POSTGRES_DB = os.environ["POSTGRES_DB"]

# Percent-encode the credentials so characters such as "@", ":", "/" or "%"
# in the user name or password are not mistaken for URL delimiters/escapes.
# Plain alphanumeric credentials are left unchanged by the encoding.
DB_URL = (
    "postgresql+psycopg://"
    f"{quote(POSTGRES_USER, safe='')}:{quote(POSTGRES_PASSWORD, safe='')}"
    f"@{POSTGRES_HOST}:{POSTGRES_PORT}/{POSTGRES_DB}"
)
engine = create_engine(DB_URL)

# Upper bound on the number of rows a vector search returns (default 20).
NUM_OF_SEARCH_RESULTS = int(os.environ.get("NUM_OF_DB_SEARCH_RESULTS", 20))

# Enable the extension before the tables are created.
# NOTE(review): presumably the models rely on the `vector` column type, so
# this has to run ahead of create_all -- confirm against dbmodel.
with engine.connect() as conn:
    conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
    conn.commit()

Base.metadata.create_all(engine)

# Apply column additions idempotently: create_all does not alter tables that
# already exist, so columns added later are patched in with IF NOT EXISTS.
with engine.connect() as conn:
    conn.execute(text("ALTER TABLE documents ADD COLUMN IF NOT EXISTS metadata_json JSONB"))
    conn.commit()
def add_document(document: Document):
    """Persist a new document and commit the transaction.

    The session is opened with ``expire_on_commit=False`` so the attributes
    already loaded on ``document`` are not expired by the commit and remain
    readable once the session has been closed.

    Args:
        document: The ``Document`` instance to insert.
    """
    with Session(engine, expire_on_commit=False) as db:
        db.add(document)
        db.commit()
def update_document(document: Document):
    """Merge the state of ``document`` into the database and commit.

    ``Session.merge`` copies the given state onto the session's own instance;
    that merged instance is not returned to the caller, and ``document``
    itself is not attached to the session.

    Args:
        document: The ``Document`` carrying the values to store.
    """
    with Session(engine, expire_on_commit=False) as db:
        db.merge(document)
        db.commit()
def get_document_by_url(url: str) -> Document | None:
    """Look up a document by its URL, loading its chunks eagerly.

    Args:
        url: The URL to match against ``Document.url``.

    Returns:
        The first matching ``Document``, or ``None`` when no row matches.
    """
    # Chunks are fetched up front because the returned object outlives the
    # session and could not lazy-load them afterwards.
    query = (
        select(Document)
        .where(Document.url == url)
        .options(selectinload(Document.chunks))
    )
    with Session(engine, expire_on_commit=False) as db:
        return db.scalars(query).first()
def get_document_by_id(id: int) -> Document | None:
    """Look up a document by primary key, loading its chunks eagerly.

    Args:
        id: Value to match against ``Document.id``. The name shadows the
            builtin but is part of the public signature, so it is kept.

    Returns:
        The matching ``Document``, or ``None`` when no row matches.
    """
    # Chunks are fetched up front because the returned object outlives the
    # session and could not lazy-load them afterwards.
    query = (
        select(Document)
        .where(Document.id == id)
        .options(selectinload(Document.chunks))
    )
    with Session(engine, expire_on_commit=False) as db:
        return db.scalars(query).first()
def search_documents(query_vector: list) -> list[dict]:
    """Return the chunks closest to ``query_vector`` by cosine distance.

    Only chunks whose cosine distance to the query is below the threshold
    are returned, ordered from closest to farthest and capped at
    ``NUM_OF_SEARCH_RESULTS`` rows.

    Args:
        query_vector: Embedding of the search query.
            NOTE(review): presumably must match the dimensionality of
            ``DocumentChunk.content_vector`` -- confirm against the model.

    Returns:
        One dict per matching chunk with the keys ``title``, ``url``,
        ``chunk``, ``id`` (``"<document id>-<order index>"``),
        ``document_id`` and ``order_index``. The list is empty when no
        chunk is within the threshold.
    """
    # For Cosine Distance: 0.0 is identical, 1.0 is orthogonal (unrelated), 2.0 is opposite.
    # A common strict threshold is 0.3 to 0.4.
    DISTANCE_THRESHOLD = 0.4

    # Build the distance expression once; it drives both filter and ordering.
    distance = DocumentChunk.content_vector.cosine_distance(query_vector)
    stmt = (
        select(DocumentChunk)
        .where(distance < DISTANCE_THRESHOLD)
        .order_by(distance)  # ascending: most similar first
        .limit(NUM_OF_SEARCH_RESULTS)
    )

    with Session(engine) as session:
        chunks = session.scalars(stmt).all()
        # Convert to plain dicts while the session is still open, since
        # ``chunk.document`` may need the session to load the parent row.
        return [
            {
                "title": chunk.document.title,
                "url": chunk.document.url,
                "chunk": chunk.content,
                "id": f"{chunk.document.id}-{chunk.order_index}",
                "document_id": chunk.document_id,
                "order_index": chunk.order_index,
            }
            for chunk in chunks
        ]
def get_documents(offset, page_size):
    """
    Fetch one page of documents together with their chunks.

    Rows are ordered by ``Document.id`` so that paging is done over a
    stable ordering.

    Args:
        offset (int): How many documents to skip before the page starts.
        page_size (int): Maximum number of documents in the page.

    Returns:
        list[Document]: The documents of the requested page, each with its
        chunks already loaded.
    """
    page_query = (
        select(Document)
        .order_by(Document.id)  # Consistency is key for pagination
        .offset(offset)
        .limit(page_size)
        .options(selectinload(Document.chunks))
    )
    with Session(engine) as db:
        page = db.scalars(page_query).all()
    return page
def get_documents_count():
    """Return how many documents are stored in the database."""
    count_query = select(func.count()).select_from(Document)
    with Session(engine) as db:
        return db.scalar(count_query)
def delete_document(document_id: int) -> bool:
    """Remove the document with the given ID.

    NOTE(review): whether the document's chunks are deleted along with it
    depends on the cascade configured on the model, which is not visible
    here -- confirm in dbmodel.

    Args:
        document_id: Primary key of the document to delete.

    Returns:
        ``True`` if the document existed and was deleted, ``False`` if no
        document has that ID.
    """
    with Session(engine) as db:
        target = db.get(Document, document_id)
        if target is not None:
            db.delete(target)
            db.commit()
            return True
    return False