Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ flowchart LR
U[User] --> FE[Frontend - React Vite]
FE --> API[Flask API]
API --> SCHEMA[Schema - CSV or introspection]
API --> LLM[Gemini - LangChain]
API --> LLM[Gemini/OpenRouter - LangChain]
API --> DB[(SQL database)]
API --> R[(Redis - optional)]
API --> M[Metrics - Prometheus]
Expand Down Expand Up @@ -159,7 +159,7 @@ Base URL (local): `http://127.0.0.1:5000`
- Deployed on Vercel

**Backend**
- Flask + SQLAlchemy + LangChain + Gemini
- Flask + SQLAlchemy + LangChain + Gemini/OpenRouter
- Rate limiting + structured logging
- Optional: Redis + RQ + Prometheus

Expand Down Expand Up @@ -190,9 +190,19 @@ Create `.env` in the repo root (do **not** commit it):

```env
# --- LLM ---
LLM_PROVIDER=gemini # gemini | openrouter
GOOGLE_API_KEY=YOUR_GOOGLE_API_KEY
GEMINI_MODEL=gemini-2.5-flash

# --- OpenRouter (optional) ---
OPENROUTER_API_KEY=YOUR_OPENROUTER_API_KEY
OPENROUTER_MODEL=openai/gpt-4o-mini
OPENROUTER_EMBEDDINGS_MODEL=text-embedding-3-small
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
OPENROUTER_SITE_URL=
OPENROUTER_APP_NAME=AskDB
EMBEDDINGS_PROVIDER=gemini # gemini | openrouter

# --- Demo DB (default) ---
DATABASE_URL=postgresql+psycopg2://user:pass@host:5432/postgres?sslmode=require

Expand Down
13 changes: 13 additions & 0 deletions SETUP.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,21 @@ python -c "import langchain; import flask; print('Installation successful!')"
Create a `.env` file in the project root:

```bash
# LLM Provider
LLM_PROVIDER=gemini # gemini | openrouter

# Google AI Configuration
GOOGLE_API_KEY=your_google_api_key_here
GEMINI_MODEL=gemini-2.5-flash

# OpenRouter Configuration (optional)
OPENROUTER_API_KEY=your_openrouter_api_key_here
OPENROUTER_MODEL=openai/gpt-4o-mini
OPENROUTER_EMBEDDINGS_MODEL=text-embedding-3-small
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
OPENROUTER_SITE_URL=
OPENROUTER_APP_NAME=AskDB
EMBEDDINGS_PROVIDER=gemini # gemini | openrouter

# LangChain Configuration
LANGCHAIN_TRACING_V2=true
Expand Down
13 changes: 6 additions & 7 deletions TECHNICAL.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ AskDB is built on LangChain, leveraging its powerful framework for building AI a

#### **LLM Integration**
```python
# Google Gemini integration via LangChain
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0)
# Provider-configurable LLM via LangChain (Gemini/OpenRouter)
# .env: LLM_PROVIDER=gemini  (or openrouter)
llm = get_llm()
```

#### **Prompt Management**
Expand Down Expand Up @@ -181,11 +182,9 @@ Final Answer

### Vector Embeddings
```python
# Uses Google's embedding model for semantic similarity
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
task_type="retrieval_query"
)
# Provider-configurable embeddings for semantic similarity
# .env: EMBEDDINGS_PROVIDER=gemini  (or openrouter)
embeddings = get_embeddings()
```

### Structured Output
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pandas==2.2.2
pydantic==2.7.4
langchain==0.2.6
langchain-community==0.2.6
langchain-openai==0.1.8
openai==1.30.1
google-genai==0.5.0
proto-plus==1.24.0
protobuf==4.25.3
Expand Down
131 changes: 104 additions & 27 deletions untitled0.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,77 @@
except Exception: # pragma: no cover
from langchain_classic.chains.sql_database.query import create_sql_query_chain
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from pydantic import BaseModel, Field, ValidationError

load_dotenv()

# ------------------------------------------------------------
# LLM
# LLM / Embeddings
# ------------------------------------------------------------
# Backend used for chat completions; read from LLM_PROVIDER, lower-cased,
# defaulting to "gemini".
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "gemini").lower()
# Backend used for embeddings; falls back to the chat provider when
# EMBEDDINGS_PROVIDER is not set.
EMBEDDINGS_PROVIDER = os.getenv("EMBEDDINGS_PROVIDER", LLM_PROVIDER).lower()

# Gemini chat model name.
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
# NOTE(review): this builds a Gemini client at import time regardless of
# LLM_PROVIDER. It is presumably superseded by the lazily-built client from
# get_llm() — confirm it is no longer needed.
llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, temperature=0)

# OpenRouter settings. The API key defaults to an empty string; get_llm() and
# get_embeddings() reject an empty key when the provider is "openrouter".
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "openai/gpt-4o-mini")
OPENROUTER_EMBEDDINGS_MODEL = os.getenv("OPENROUTER_EMBEDDINGS_MODEL", "text-embedding-3-small")
OPENROUTER_BASE_URL = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
# Optional values forwarded as the HTTP-Referer / X-Title request headers
# (see _openrouter_headers).
OPENROUTER_SITE_URL = os.getenv("OPENROUTER_SITE_URL", "")
OPENROUTER_APP_NAME = os.getenv("OPENROUTER_APP_NAME", "AskDB")

# Module-level caches filled on first use by get_llm() / get_embeddings().
_llm = None
_embeddings = None

def _openrouter_headers() -> Optional[Dict[str, str]]:
    """Build the optional extra headers sent with OpenRouter requests.

    Returns:
        A dict holding ``HTTP-Referer`` (from ``OPENROUTER_SITE_URL``) and/or
        ``X-Title`` (from ``OPENROUTER_APP_NAME``) for whichever settings are
        non-empty, or ``None`` when both are empty.
    """
    candidates = (
        ("HTTP-Referer", OPENROUTER_SITE_URL),
        ("X-Title", OPENROUTER_APP_NAME),
    )
    # Keep only the headers whose configured value is truthy.
    populated: Dict[str, str] = {name: value for name, value in candidates if value}
    return populated if populated else None

def get_llm():
    """Return the shared chat model, constructing it on the first call.

    The instance is stored in the module-level ``_llm`` and reused afterwards.

    Raises:
        ValueError: if ``LLM_PROVIDER`` is "openrouter" but
            ``OPENROUTER_API_KEY`` is empty, or if ``LLM_PROVIDER`` is neither
            "openrouter" nor "gemini".
    """
    global _llm
    if _llm is None:
        if LLM_PROVIDER == "openrouter":
            if not OPENROUTER_API_KEY:
                raise ValueError("OPENROUTER_API_KEY is required when LLM_PROVIDER=openrouter")
            # OpenRouter is reached through the OpenAI-compatible client,
            # pointed at the configured base URL.
            _llm = ChatOpenAI(
                model=OPENROUTER_MODEL,
                temperature=0,
                openai_api_key=OPENROUTER_API_KEY,
                base_url=OPENROUTER_BASE_URL,
                default_headers=_openrouter_headers(),
            )
        elif LLM_PROVIDER == "gemini":
            _llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, temperature=0)
        else:
            raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}")
    return _llm

def get_embeddings():
    """Return the shared embeddings client, constructing it on the first call.

    The instance is stored in the module-level ``_embeddings`` and reused
    afterwards.

    Raises:
        ValueError: if ``EMBEDDINGS_PROVIDER`` is "openrouter" but
            ``OPENROUTER_API_KEY`` is empty, or if ``EMBEDDINGS_PROVIDER`` is
            neither "openrouter" nor "gemini".
    """
    global _embeddings
    if _embeddings is None:
        if EMBEDDINGS_PROVIDER == "openrouter":
            if not OPENROUTER_API_KEY:
                raise ValueError("OPENROUTER_API_KEY is required when EMBEDDINGS_PROVIDER=openrouter")
            # Same OpenAI-compatible client setup as the chat model, but with
            # the embeddings model name.
            _embeddings = OpenAIEmbeddings(
                model=OPENROUTER_EMBEDDINGS_MODEL,
                openai_api_key=OPENROUTER_API_KEY,
                base_url=OPENROUTER_BASE_URL,
                default_headers=_openrouter_headers(),
            )
        elif EMBEDDINGS_PROVIDER == "gemini":
            _embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query")
        else:
            raise ValueError(f"Unsupported EMBEDDINGS_PROVIDER: {EMBEDDINGS_PROVIDER}")
    return _embeddings

# ------------------------------------------------------------
# Defaults / knobs
Expand Down Expand Up @@ -609,7 +669,7 @@ def build_few_shot_prompt() -> FewShotChatMessagePromptTemplate:
return static_prompt

try:
embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query")
embeddings = get_embeddings()
selector = SemanticSimilarityExampleSelector.from_examples(
examples=EXAMPLES,
embeddings=embeddings,
Expand Down Expand Up @@ -706,7 +766,7 @@ def select_tables_for_question(question: str, table_details: str) -> List[str]:

# LLM selection (best-effort)
try:
raw = (_TABLE_SELECT_PROMPT | llm | StrOutputParser()).invoke(
raw = (_TABLE_SELECT_PROMPT | get_llm() | StrOutputParser()).invoke(
{"question": question, "table_details": table_details}
)
picked = _parse_table_list(raw, allowed)
Expand Down Expand Up @@ -786,7 +846,10 @@ def _schema_rag_select(question: str, table_details: str, db_cache_key: str) ->
# Build / reuse vectorstore
vs_key = f"schema_vs::{db_cache_key}"
if vs_key not in _schema_vs_cache:
embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query")
try:
embeddings = get_embeddings()
except Exception:
return table_details
persist_dir = ".chroma_schema" if SCHEMA_RAG_PERSIST else None
vs = Chroma(
collection_name=f"askdb_schema_{hashlib.md5(db_cache_key.encode()).hexdigest()[:12]}",
Expand Down Expand Up @@ -824,7 +887,6 @@ def _schema_rag_select(question: str, table_details: str, db_cache_key: str) ->

Write the answer clearly and concisely in business-friendly language."""
)
rephrase_answer = answer_prompt | llm | StrOutputParser()

class ChartSpec(BaseModel):
summary: str = Field(description="1-2 sentence executive summary of the results")
Expand Down Expand Up @@ -912,7 +974,7 @@ def is_date_like(v) -> bool:
)
try:
msg = prompt.format_messages(q=question, cols=", ".join(cols), rows=json.dumps(sample, default=str))
raw = str(llm.invoke(msg).content).strip()
raw = str(get_llm().invoke(msg).content).strip()
raw = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw, flags=re.DOTALL).strip()
data = json.loads(raw)
spec = ChartSpec(**data)
Expand All @@ -938,7 +1000,7 @@ def _repair_sql(question: str, table_info: str, dialect: str, mode: str, bad_sql
]
)
msg = prompt.format_messages(q=question, mode=mode, ti=table_info, sql=bad_sql, err=error_msg)
fixed = llm.invoke(msg).content
fixed = get_llm().invoke(msg).content
return clean_sql_query(str(fixed))

def _rewrite_sql_for_budget(question: str, table_info: str, dialect: str, mode: str, sql: str, explain: Dict[str, Any]) -> str:
Expand All @@ -958,7 +1020,7 @@ def _rewrite_sql_for_budget(question: str, table_info: str, dialect: str, mode:
]
)
msg = prompt.format_messages(q=question, mode=mode, ti=table_info, sql=sql, ex=json.dumps(explain))
out = llm.invoke(msg).content
out = get_llm().invoke(msg).content
return clean_sql_query(str(out))

def _rewrite_sql_for_speed(question: str, table_info: str, dialect: str, mode: str, sql: str, db_ms: int) -> str:
Expand All @@ -978,7 +1040,7 @@ def _rewrite_sql_for_speed(question: str, table_info: str, dialect: str, mode: s
]
)
msg = prompt.format_messages(q=question, mode=mode, ti=table_info, ms=str(db_ms), sql=sql, top_k=str(TOP_K_DEFAULT))
out = llm.invoke(msg).content
out = get_llm().invoke(msg).content
return clean_sql_query(str(out))

# ------------------------------------------------------------
Expand Down Expand Up @@ -1023,7 +1085,9 @@ def _cache_set(key: str, value: Dict[str, Any]) -> None:

def build_chain_for_db(db_: SQLDatabase, dialect: str) -> Any:
final_prompt = make_final_prompt(dialect)
generate_query = create_sql_query_chain(llm, db_, final_prompt)
model = get_llm()
generate_query = create_sql_query_chain(model, db_, final_prompt)
rephrase_answer = answer_prompt | model | StrOutputParser()

def _run_exec(inputs: dict) -> dict:
sql = inputs.get("query", "")
Expand Down Expand Up @@ -1193,12 +1257,23 @@ def get_or_build_chain(db_url: str, schema_csv_text: Optional[str]) -> Dict[str,
# Default (demo) DB from env
# ------------------------------------------------------------
DEMO_DATABASE_URL = os.getenv("DATABASE_URL", "")
if not DEMO_DATABASE_URL:
raise RuntimeError("DATABASE_URL is not set. Provide a demo DB URL (e.g., Supabase Postgres).")

_demo_ctx = get_or_build_chain(DEMO_DATABASE_URL, schema_csv_text=(Path(__file__).resolve().parent / "database_table_descriptions.csv").read_text(encoding="utf-8", errors="ignore") if (Path(__file__).resolve().parent / "database_table_descriptions.csv").exists() else None)
_demo_chain = _demo_ctx["chain"]
_demo_table_details = _demo_ctx["table_details"]
# Demo context cache; stays None until _get_demo_ctx() builds it.
_demo_ctx: Optional[Dict[str, Any]] = None
# Set by _get_demo_ctx() to the context's "chain" and "table_details" entries.
_demo_chain = None
_demo_table_details = None

def _get_demo_ctx() -> Dict[str, Any]:
    """Return the demo database context, building and caching it on first call.

    On the first call this reads ``database_table_descriptions.csv`` from the
    directory containing this file (when it exists), passes its text to
    ``get_or_build_chain`` with ``DEMO_DATABASE_URL``, and stores the result in
    the module-level ``_demo_ctx``, ``_demo_chain`` and ``_demo_table_details``.

    Raises:
        RuntimeError: if ``DEMO_DATABASE_URL`` is empty.
    """
    global _demo_ctx, _demo_chain, _demo_table_details
    if _demo_ctx is None:
        if not DEMO_DATABASE_URL:
            raise RuntimeError("DATABASE_URL is not set. Provide a demo DB URL (e.g., Supabase Postgres).")
        csv_file = Path(__file__).resolve().parent / "database_table_descriptions.csv"
        # The descriptions file is optional; pass None when it is absent.
        if csv_file.exists():
            descriptions = csv_file.read_text(encoding="utf-8", errors="ignore")
        else:
            descriptions = None
        _demo_ctx = get_or_build_chain(DEMO_DATABASE_URL, schema_csv_text=descriptions)
        _demo_chain = _demo_ctx["chain"]
        _demo_table_details = _demo_ctx["table_details"]
    return _demo_ctx

# ------------------------------------------------------------
# Public API used by code1.py
Expand All @@ -1225,11 +1300,12 @@ def get_schema_tables(
"warnings": ctx["warnings"],
}
# demo
ctx = _get_demo_ctx()
return {
"tables": _demo_ctx["tables"],
"schema_source": _demo_ctx["schema_source"],
"dialect": _demo_ctx["dialect"],
"host": _demo_ctx["host"],
"tables": ctx["tables"],
"schema_source": ctx["schema_source"],
"dialect": ctx["dialect"],
"host": ctx["host"],
"warnings": [],
}

Expand Down Expand Up @@ -1282,24 +1358,25 @@ def chain_code(q: str, m: List[Dict[str, str]], mode: str = "public", db_url_ove
out["dialect"] = ctx["dialect"]
out["host"] = ctx["host"]
else:
ctx = _get_demo_ctx()
cache_db_url = DEMO_DATABASE_URL
cache_schema_source = _demo_ctx.get("schema_source", "csv")
cache_dialect = _demo_ctx.get("dialect", "demo")
cache_schema_source = ctx.get("schema_source", "csv")
cache_dialect = ctx.get("dialect", "demo")
if mode == "public":
ck = _cache_key(cache_db_url, cache_schema_source, cache_dialect, mode, q_norm)
cached = _cache_get(ck)
if cached:
return cached

out = _demo_chain.invoke({
out = ctx["chain"].invoke({
"question": q,
"table_details": _demo_table_details,
"table_details": ctx["table_details"],
"mode": mode,
"top_k": top_k,
})
out["schema_source"] = _demo_ctx["schema_source"]
out["dialect"] = _demo_ctx["dialect"]
out["host"] = _demo_ctx["host"]
out["schema_source"] = ctx["schema_source"]
out["dialect"] = ctx["dialect"]
out["host"] = ctx["host"]

# Add insights + chart for SELECT results
if out.get("kind") == "SELECT":
Expand Down
Loading