-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
38 lines (29 loc) · 1.12 KB
/
database.py
File metadata and controls
38 lines (29 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import datetime
from typing import Optional, List
from sqlmodel import Field, SQLModel
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column
from config import config
from uuid import UUID, uuid4
def _utc_now_naive() -> datetime.datetime:
    """Return the current UTC time as a naive ``datetime`` (``tzinfo`` is None).

    Drop-in replacement for ``datetime.datetime.utcnow()``, which is deprecated
    since Python 3.12. It produces the same naive-UTC value, so ``created_at``
    keeps exactly the representation it had before.
    """
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


class UserPersona(SQLModel, table=True):
    """SQLModel table model holding a user's persona text and its embedding vector."""

    # Primary key; a fresh uuid4 is generated when no id is supplied.
    id: Optional[UUID] = Field(default_factory=uuid4, primary_key=True)
    username: str
    persona: str
    # pgvector column of dimension 1536: the OpenAI embedding dimension of the
    # default model 'text-embedding-ada-002'.
    embedding: List[float] = Field(sa_column=Column(Vector(1536)))
    # Creation timestamp in UTC, kept naive (no tzinfo) to match the values the
    # previous ``datetime.datetime.utcnow`` default produced.
    created_at: datetime.datetime = Field(default_factory=_utc_now_naive)
# Module-level async engine built from the configured database URL;
# SQL statement echoing is turned off.
engine = create_async_engine(config.db_url, echo=False)

# Session factory: each call yields a new AsyncSession bound to ``engine``,
# with autoflush disabled and autocommit off.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
)
async def get_db_session():
    """Yield a database session, committing on success and rolling back on error.

    This is an async generator, presumably used as a yield-style dependency
    (e.g. FastAPI ``Depends``) — confirm against callers.

    Behavior:
      * When the consumer resumes the generator normally, the session is
        committed.
      * If an ``Exception`` is raised into the generator, or by the commit
        itself, the session is rolled back and the exception is re-raised.
      * The ``async with`` block closes the session on every exit path, so no
        separate ``close()`` call is needed.

    Yields:
        A session created by ``AsyncSessionLocal`` (configured with
        ``class_=AsyncSession``).
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            # Discard partial work before propagating the original error.
            await session.rollback()
            raise