diff --git a/README.md b/README.md index f65f67a..bdc7306 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ flowchart LR U[User] --> FE[Frontend - React Vite] FE --> API[Flask API] API --> SCHEMA[Schema - CSV or introspection] - API --> LLM[Gemini - LangChain] + API --> LLM[Gemini/OpenRouter - LangChain] API --> DB[(SQL database)] API --> R[(Redis - optional)] API --> M[Metrics - Prometheus] @@ -159,7 +159,7 @@ Base URL (local): `http://127.0.0.1:5000` - Deployed on Vercel **Backend** -- Flask + SQLAlchemy + LangChain + Gemini +- Flask + SQLAlchemy + LangChain + Gemini/OpenRouter - Rate limiting + structured logging - Optional: Redis + RQ + Prometheus @@ -190,9 +190,19 @@ Create `.env` in the repo root (do **not** commit it): ```env # --- LLM --- +LLM_PROVIDER=gemini # gemini | openrouter GOOGLE_API_KEY=YOUR_GOOGLE_API_KEY GEMINI_MODEL=gemini-2.5-flash +# --- OpenRouter (optional) --- +OPENROUTER_API_KEY=YOUR_OPENROUTER_API_KEY +OPENROUTER_MODEL=openai/gpt-4o-mini +OPENROUTER_EMBEDDINGS_MODEL=text-embedding-3-small +OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 +OPENROUTER_SITE_URL= +OPENROUTER_APP_NAME=AskDB +EMBEDDINGS_PROVIDER=gemini # gemini | openrouter + # --- Demo DB (default) --- DATABASE_URL=postgresql+psycopg2://user:pass@host:5432/postgres?sslmode=require diff --git a/SETUP.md b/SETUP.md index 37c8470..96fa5ea 100644 --- a/SETUP.md +++ b/SETUP.md @@ -55,8 +55,21 @@ python -c "import langchain; import flask; print('Installation successful!')" Create a `.env` file in the project root: ```bash +# LLM Provider +LLM_PROVIDER=gemini # gemini | openrouter + # Google AI Configuration GOOGLE_API_KEY=your_google_api_key_here +GEMINI_MODEL=gemini-2.5-flash + +# OpenRouter Configuration (optional) +OPENROUTER_API_KEY=your_openrouter_api_key_here +OPENROUTER_MODEL=openai/gpt-4o-mini +OPENROUTER_EMBEDDINGS_MODEL=text-embedding-3-small +OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 +OPENROUTER_SITE_URL= +OPENROUTER_APP_NAME=AskDB 
+EMBEDDINGS_PROVIDER=gemini # gemini | openrouter # LangChain Configuration LANGCHAIN_TRACING_V2=true diff --git a/TECHNICAL.md b/TECHNICAL.md index 7b12885..2a887e5 100644 --- a/TECHNICAL.md +++ b/TECHNICAL.md @@ -59,8 +59,9 @@ AskDB is built on LangChain, leveraging its powerful framework for building AI a #### **LLM Integration** ```python -# Google Gemini integration via LangChain -llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0) +# Provider-configurable LLM via LangChain (Gemini/OpenRouter) +# Provider is chosen by the LLM_PROVIDER env var: "gemini" (default) or "openrouter" +llm = get_llm() ``` #### **Prompt Management** @@ -181,11 +182,9 @@ Final Answer ### Vector Embeddings ```python -# Uses Google's embedding model for semantic similarity -embeddings = GoogleGenerativeAIEmbeddings( - model="models/embedding-001", - task_type="retrieval_query" -) +# Provider-configurable embeddings for semantic similarity +# Provider is chosen by the EMBEDDINGS_PROVIDER env var (defaults to LLM_PROVIDER) +embeddings = get_embeddings() ``` ### Structured Output diff --git a/requirements.txt b/requirements.txt index 99a0285..5ee72d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,8 @@ pandas==2.2.2 pydantic==2.7.4 langchain==0.2.6 langchain-community==0.2.6 +langchain-openai==0.1.8 +openai==1.30.1 google-genai==0.5.0 proto-plus==1.24.0 protobuf==4.25.3 diff --git a/untitled0.py b/untitled0.py index fddc9d8..5ca5700 100644 --- a/untitled0.py +++ b/untitled0.py @@ -35,6 +35,7 @@ except Exception: # pragma: no cover from langchain_classic.chains.sql_database.query import create_sql_query_chain from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings +from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_chroma import Chroma from langchain_core.example_selectors import SemanticSimilarityExampleSelector from pydantic import BaseModel, Field, ValidationError @@ -42,10 +43,69 @@ load_dotenv() # ------------------------------------------------------------ -# LLM +# LLM / Embeddings # 
------------------------------------------------------------ +LLM_PROVIDER = os.getenv("LLM_PROVIDER", "gemini").lower() +EMBEDDINGS_PROVIDER = os.getenv("EMBEDDINGS_PROVIDER", LLM_PROVIDER).lower() + GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.5-flash") -llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, temperature=0) + +OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "") +OPENROUTER_MODEL = os.getenv("OPENROUTER_MODEL", "openai/gpt-4o-mini") +OPENROUTER_EMBEDDINGS_MODEL = os.getenv("OPENROUTER_EMBEDDINGS_MODEL", "text-embedding-3-small") +OPENROUTER_BASE_URL = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1") +OPENROUTER_SITE_URL = os.getenv("OPENROUTER_SITE_URL", "") +OPENROUTER_APP_NAME = os.getenv("OPENROUTER_APP_NAME", "AskDB") + +_llm = None +_embeddings = None + +def _openrouter_headers() -> Optional[Dict[str, str]]: + headers: Dict[str, str] = {} + if OPENROUTER_SITE_URL: + headers["HTTP-Referer"] = OPENROUTER_SITE_URL + if OPENROUTER_APP_NAME: + headers["X-Title"] = OPENROUTER_APP_NAME + return headers or None + +def get_llm(): + global _llm + if _llm is not None: + return _llm + if LLM_PROVIDER == "openrouter": + if not OPENROUTER_API_KEY: + raise ValueError("OPENROUTER_API_KEY is required when LLM_PROVIDER=openrouter") + _llm = ChatOpenAI( + model=OPENROUTER_MODEL, + temperature=0, + openai_api_key=OPENROUTER_API_KEY, + base_url=OPENROUTER_BASE_URL, + default_headers=_openrouter_headers(), + ) + return _llm + if LLM_PROVIDER == "gemini": + _llm = ChatGoogleGenerativeAI(model=GEMINI_MODEL, temperature=0) + return _llm + raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}") + +def get_embeddings(): + global _embeddings + if _embeddings is not None: + return _embeddings + if EMBEDDINGS_PROVIDER == "openrouter": + if not OPENROUTER_API_KEY: + raise ValueError("OPENROUTER_API_KEY is required when EMBEDDINGS_PROVIDER=openrouter") + _embeddings = OpenAIEmbeddings( + model=OPENROUTER_EMBEDDINGS_MODEL, + 
openai_api_key=OPENROUTER_API_KEY, + base_url=OPENROUTER_BASE_URL, + default_headers=_openrouter_headers(), + ) + return _embeddings + if EMBEDDINGS_PROVIDER == "gemini": + _embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query") + return _embeddings + raise ValueError(f"Unsupported EMBEDDINGS_PROVIDER: {EMBEDDINGS_PROVIDER}") # ------------------------------------------------------------ # Defaults / knobs @@ -609,7 +669,7 @@ def build_few_shot_prompt() -> FewShotChatMessagePromptTemplate: return static_prompt try: - embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query") + embeddings = get_embeddings() selector = SemanticSimilarityExampleSelector.from_examples( examples=EXAMPLES, embeddings=embeddings, @@ -706,7 +766,7 @@ def select_tables_for_question(question: str, table_details: str) -> List[str]: # LLM selection (best-effort) try: - raw = (_TABLE_SELECT_PROMPT | llm | StrOutputParser()).invoke( + raw = (_TABLE_SELECT_PROMPT | get_llm() | StrOutputParser()).invoke( {"question": question, "table_details": table_details} ) picked = _parse_table_list(raw, allowed) @@ -786,7 +846,10 @@ def _schema_rag_select(question: str, table_details: str, db_cache_key: str) -> # Build / reuse vectorstore vs_key = f"schema_vs::{db_cache_key}" if vs_key not in _schema_vs_cache: - embeddings = GoogleGenerativeAIEmbeddings(model="models/gemini-embedding-001", task_type="retrieval_query") + try: + embeddings = get_embeddings() + except Exception: + return table_details persist_dir = ".chroma_schema" if SCHEMA_RAG_PERSIST else None vs = Chroma( collection_name=f"askdb_schema_{hashlib.md5(db_cache_key.encode()).hexdigest()[:12]}", @@ -824,7 +887,6 @@ def _schema_rag_select(question: str, table_details: str, db_cache_key: str) -> Write the answer clearly and concisely in business-friendly language.""" ) -rephrase_answer = answer_prompt | llm | StrOutputParser() class 
ChartSpec(BaseModel): summary: str = Field(description="1-2 sentence executive summary of the results") @@ -912,7 +974,7 @@ def is_date_like(v) -> bool: ) try: msg = prompt.format_messages(q=question, cols=", ".join(cols), rows=json.dumps(sample, default=str)) - raw = str(llm.invoke(msg).content).strip() + raw = str(get_llm().invoke(msg).content).strip() raw = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw, flags=re.DOTALL).strip() data = json.loads(raw) spec = ChartSpec(**data) @@ -938,7 +1000,7 @@ def _repair_sql(question: str, table_info: str, dialect: str, mode: str, bad_sql ] ) msg = prompt.format_messages(q=question, mode=mode, ti=table_info, sql=bad_sql, err=error_msg) - fixed = llm.invoke(msg).content + fixed = get_llm().invoke(msg).content return clean_sql_query(str(fixed)) def _rewrite_sql_for_budget(question: str, table_info: str, dialect: str, mode: str, sql: str, explain: Dict[str, Any]) -> str: @@ -958,7 +1020,7 @@ def _rewrite_sql_for_budget(question: str, table_info: str, dialect: str, mode: ] ) msg = prompt.format_messages(q=question, mode=mode, ti=table_info, sql=sql, ex=json.dumps(explain)) - out = llm.invoke(msg).content + out = get_llm().invoke(msg).content return clean_sql_query(str(out)) def _rewrite_sql_for_speed(question: str, table_info: str, dialect: str, mode: str, sql: str, db_ms: int) -> str: @@ -978,7 +1040,7 @@ def _rewrite_sql_for_speed(question: str, table_info: str, dialect: str, mode: s ] ) msg = prompt.format_messages(q=question, mode=mode, ti=table_info, ms=str(db_ms), sql=sql, top_k=str(TOP_K_DEFAULT)) - out = llm.invoke(msg).content + out = get_llm().invoke(msg).content return clean_sql_query(str(out)) # ------------------------------------------------------------ @@ -1023,7 +1085,9 @@ def _cache_set(key: str, value: Dict[str, Any]) -> None: def build_chain_for_db(db_: SQLDatabase, dialect: str) -> Any: final_prompt = make_final_prompt(dialect) - generate_query = create_sql_query_chain(llm, db_, final_prompt) + model = get_llm() 
+ generate_query = create_sql_query_chain(model, db_, final_prompt) + rephrase_answer = answer_prompt | model | StrOutputParser() def _run_exec(inputs: dict) -> dict: sql = inputs.get("query", "") @@ -1193,12 +1257,23 @@ def get_or_build_chain(db_url: str, schema_csv_text: Optional[str]) -> Dict[str, # Default (demo) DB from env # ------------------------------------------------------------ DEMO_DATABASE_URL = os.getenv("DATABASE_URL", "") -if not DEMO_DATABASE_URL: - raise RuntimeError("DATABASE_URL is not set. Provide a demo DB URL (e.g., Supabase Postgres).") -_demo_ctx = get_or_build_chain(DEMO_DATABASE_URL, schema_csv_text=(Path(__file__).resolve().parent / "database_table_descriptions.csv").read_text(encoding="utf-8", errors="ignore") if (Path(__file__).resolve().parent / "database_table_descriptions.csv").exists() else None) -_demo_chain = _demo_ctx["chain"] -_demo_table_details = _demo_ctx["table_details"] +_demo_ctx: Optional[Dict[str, Any]] = None +_demo_chain = None +_demo_table_details = None + +def _get_demo_ctx() -> Dict[str, Any]: + global _demo_ctx, _demo_chain, _demo_table_details + if _demo_ctx is not None: + return _demo_ctx + if not DEMO_DATABASE_URL: + raise RuntimeError("DATABASE_URL is not set. 
Provide a demo DB URL (e.g., Supabase Postgres).") + schema_csv_path = Path(__file__).resolve().parent / "database_table_descriptions.csv" + schema_text = schema_csv_path.read_text(encoding="utf-8", errors="ignore") if schema_csv_path.exists() else None + _demo_ctx = get_or_build_chain(DEMO_DATABASE_URL, schema_csv_text=schema_text) + _demo_chain = _demo_ctx["chain"] + _demo_table_details = _demo_ctx["table_details"] + return _demo_ctx # ------------------------------------------------------------ # Public API used by code1.py @@ -1225,11 +1300,12 @@ def get_schema_tables( "warnings": ctx["warnings"], } # demo + ctx = _get_demo_ctx() return { - "tables": _demo_ctx["tables"], - "schema_source": _demo_ctx["schema_source"], - "dialect": _demo_ctx["dialect"], - "host": _demo_ctx["host"], + "tables": ctx["tables"], + "schema_source": ctx["schema_source"], + "dialect": ctx["dialect"], + "host": ctx["host"], "warnings": [], } @@ -1282,24 +1358,25 @@ def chain_code(q: str, m: List[Dict[str, str]], mode: str = "public", db_url_ove out["dialect"] = ctx["dialect"] out["host"] = ctx["host"] else: + ctx = _get_demo_ctx() cache_db_url = DEMO_DATABASE_URL - cache_schema_source = _demo_ctx.get("schema_source", "csv") - cache_dialect = _demo_ctx.get("dialect", "demo") + cache_schema_source = ctx.get("schema_source", "csv") + cache_dialect = ctx.get("dialect", "demo") if mode == "public": ck = _cache_key(cache_db_url, cache_schema_source, cache_dialect, mode, q_norm) cached = _cache_get(ck) if cached: return cached - out = _demo_chain.invoke({ + out = ctx["chain"].invoke({ "question": q, - "table_details": _demo_table_details, + "table_details": ctx["table_details"], "mode": mode, "top_k": top_k, }) - out["schema_source"] = _demo_ctx["schema_source"] - out["dialect"] = _demo_ctx["dialect"] - out["host"] = _demo_ctx["host"] + out["schema_source"] = ctx["schema_source"] + out["dialect"] = ctx["dialect"] + out["host"] = ctx["host"] # Add insights + chart for SELECT results if 
out.get("kind") == "SELECT":