Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ MYSQL_USER=docbrain
MYSQL_PASSWORD=password
MYSQL_DATABASE=docbrain

# CORS (comma-separated origins)
CORS_ORIGINS=http://localhost:5173,http://127.0.0.1:5173

# Rate Limiting
RATE_LIMIT_PER_MINUTE=60

# Email (SendGrid)
SENDGRID_API_KEY=your-sendgrid-api-key
FROM_EMAIL=noreply@yourdomain.com

# File Upload
MAX_FILE_SIZE_MB=10
UPLOAD_DIR=/data/uploads
65 changes: 65 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Continuous integration: runs linting and the test suite on every push and
# pull request that targets one of the long-lived branches.
name: CI

on:
  push:
    branches: [main, dev, develop]
  pull_request:
    branches: [main, dev, develop]

jobs:
  # Formatting, import-order and static lint checks (mirrors `make lint`).
  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: pip

      - name: Install linting tools
        run: pip install flake8 black isort

      - name: Check formatting with black
        run: black --check --diff app/ tests/

      - name: Check import ordering with isort
        run: isort --check-only --diff app/ tests/

      - name: Lint with flake8
        run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402

  # Unit tests with coverage, installed from the slim test requirements.
  test:
    name: Test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.11"
          cache: pip
          # Key the pip cache on the file this job actually installs from;
          # the default glob only looks at requirements.txt / pyproject.toml.
          cache-dependency-path: requirements-test.txt

      - name: Install test dependencies
        run: pip install -r requirements-test.txt

      - name: Run tests with coverage
        run: pytest tests/ -v --cov=app --cov-report=term-missing --cov-report=xml
        # Placeholder values so the application settings can be constructed in CI.
        env:
          ENVIRONMENT: test
          SECRET_KEY: ci-test-secret-key
          SENDGRID_API_KEY: test
          FROM_EMAIL: test@example.com
          PINECONE_API_KEY: test
          PINECONE_ENVIRONMENT: test
          WHITELISTED_EMAILS: test@example.com

      - name: Upload coverage report
        uses: actions/upload-artifact@v4
        # Upload even when tests fail so the coverage of the failing run is kept.
        if: always()
        with:
          name: coverage-report
          path: coverage.xml
14 changes: 13 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,19 @@ worker:
sh ./restart_worker.sh

test:
pytest
pytest tests/ -v

test-cov:
pytest tests/ -v --cov=app --cov-report=term-missing

lint:
black --check --diff app/ tests/
isort --check-only --diff app/ tests/
flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402

format:
black app/ tests/
isort app/ tests/

clean:
find . -type d -name "__pycache__" -exec rm -r {} +
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

# DocBrain - Self-Hosted RAG Framework

![Python Version](https://img.shields.io/badge/python-3.8%2B-blue)
[![CI](https://github.com/shivama205/DocBrain/actions/workflows/ci.yml/badge.svg?branch=dev)](https://github.com/shivama205/DocBrain/actions/workflows/ci.yml)
![Python Version](https://img.shields.io/badge/python-3.11%2B-blue)
![License](https://img.shields.io/badge/license-MIT-green)
![Security](https://img.shields.io/badge/security-self--hosted-brightgreen)

Expand Down
2 changes: 1 addition & 1 deletion app/api/endpoints/knowledge_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ async def get_shared_users(
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
async def create_document(
kb_id: str = Path(..., description="Knowledge base ID"),
file: UploadFile = Annotated[..., File(..., description="Document to upload")],
file: UploadFile = File(..., description="Document to upload"),
current_user: UserResponse = Depends(get_current_user),
doc_service: DocumentService = Depends(get_document_service)
):
Expand Down
20 changes: 15 additions & 5 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ class Settings(BaseSettings):
# Security
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key")
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours

# Email Settings
SENDGRID_API_KEY: str
FROM_EMAIL: EmailStr
SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "")
FROM_EMAIL: str = os.getenv("FROM_EMAIL", "noreply@example.com")

# Vector Store
PINECONE_API_KEY: str = os.getenv("PINECONE_API_KEY", "")
PINECONE_ENVIRONMENT: str
PINECONE_ENVIRONMENT: str = os.getenv("PINECONE_ENVIRONMENT", "")
PINECONE_INDEX_NAME: str = os.getenv("PINECONE_INDEX_NAME", "docbrain")
PINECONE_SUMMARY_INDEX_NAME: str = os.getenv("PINECONE_SUMMARY_INDEX_NAME", "summary")
PINECONE_QUESTIONS_INDEX_NAME: str = os.getenv("PINECONE_QUESTIONS_INDEX_NAME", "questions")
Expand All @@ -33,7 +33,7 @@ class Settings(BaseSettings):
REDIS_URL: str = "redis://localhost:6379/0"

# Test Emails
WHITELISTED_EMAILS: str
WHITELISTED_EMAILS: str = os.getenv("WHITELISTED_EMAILS", "")

# RAG
RAG_TOP_K: int = 3
Expand All @@ -44,6 +44,9 @@ class Settings(BaseSettings):
def WHITELISTED_EMAIL_LIST(self) -> List[str]:
return [email.strip() for email in self.WHITELISTED_EMAILS.split(",")]

# Rate Limiting
RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60"))

# File Upload
MAX_FILE_SIZE_MB: int = 10
UPLOAD_DIR: str = "/data/uploads"
Expand All @@ -65,6 +68,13 @@ def DATABASE_URL(self) -> str:
CELERY_RESULT_BACKEND: str = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0")

# CORS
CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173")

@property
def CORS_ORIGIN_LIST(self) -> List[str]:
"""Parse comma-separated CORS origins."""
return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()]

BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = []

# Storage
Expand Down
63 changes: 62 additions & 1 deletion app/core/middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import time
import logging
from collections import defaultdict

from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from typing import Dict, List, Callable, Optional

from app.core.permissions import Permission, get_permissions_for_role
from app.db.models.user import UserRole

logger = logging.getLogger(__name__)


class PermissionsMiddleware(BaseHTTPMiddleware):
"""
Expand Down Expand Up @@ -144,4 +151,58 @@ async def dispatch(self, request: Request, call_next: Callable):
"PUT": [Permission.MANAGE_SYSTEM],
"DELETE": [Permission.MANAGE_SYSTEM],
},
}
}


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Simple in-memory, per-client rate limiting middleware.

    Each client may make at most ``requests_per_minute`` requests within any
    rolling 60-second window; requests beyond that receive an HTTP 429
    response with a ``Retry-After`` header. State is held in this process
    only, so for production deployments with multiple workers, consider
    using a Redis-backed solution instead.

    Args:
        app: The ASGI application being wrapped.
        requests_per_minute: Maximum requests allowed per client per window.
        exempt_paths: Path prefixes that bypass rate limiting. ``None``
            selects ``DEFAULT_EXEMPT_PATHS``; an explicit empty list disables
            all exemptions.
        trust_forwarded_for: When true (the default), the first entry of the
            ``X-Forwarded-For`` header identifies the client. That header is
            client-supplied unless a reverse proxy overwrites it, so set this
            to ``False`` when the service is reachable without such a proxy.
    """

    # Length of the rolling window, in seconds.
    WINDOW_SECONDS = 60.0
    # Prefixes exempt from limiting when the caller does not supply any.
    DEFAULT_EXEMPT_PATHS = ("/health", "/docs", "/openapi.json", "/redoc")

    def __init__(
        self,
        app,
        requests_per_minute: int = 60,
        exempt_paths: Optional[List[str]] = None,
        trust_forwarded_for: bool = True,
    ):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        # Test against None so an explicit [] really means "no exemptions".
        if exempt_paths is None:
            self.exempt_paths = list(self.DEFAULT_EXEMPT_PATHS)
        else:
            self.exempt_paths = list(exempt_paths)
        self.trust_forwarded_for = trust_forwarded_for
        # {client_ip: [timestamp, ...]}
        self._requests: Dict[str, List[float]] = defaultdict(list)
        # Time of the last full sweep of idle clients (see _purge_idle_clients).
        self._last_purge: float = 0.0

    def _get_client_ip(self, request: Request) -> str:
        """Return the key used to bucket this request's client."""
        if self.trust_forwarded_for:
            forwarded = request.headers.get("x-forwarded-for")
            if forwarded:
                first_hop = forwarded.split(",")[0].strip()
                # A blank first entry must not become a shared "" bucket.
                if first_hop:
                    return first_hop
        return request.client.host if request.client else "unknown"

    def _cleanup(self, timestamps: List[float], now: float) -> List[float]:
        """Return only the timestamps that fall inside the current window."""
        cutoff = now - self.WINDOW_SECONDS
        return [t for t in timestamps if t > cutoff]

    def _purge_idle_clients(self, now: float) -> None:
        """Drop clients with no requests left in the window.

        Without this sweep every distinct client key would stay in memory
        forever. It runs at most once per window so the per-request cost
        stays negligible.
        """
        if now - self._last_purge < self.WINDOW_SECONDS:
            return
        self._last_purge = now
        idle = [ip for ip, stamps in self._requests.items() if not self._cleanup(stamps, now)]
        for ip in idle:
            del self._requests[ip]

    async def dispatch(self, request: Request, call_next: Callable):
        """Forward the request, or answer 429 when the client is over its limit."""
        # str.startswith accepts a tuple; an empty tuple never matches.
        if request.url.path.startswith(tuple(self.exempt_paths)):
            return await call_next(request)

        now = time.time()
        self._purge_idle_clients(now)

        client_ip = self._get_client_ip(request)

        # Discard expired entries before counting this client's requests.
        recent = self._cleanup(self._requests.get(client_ip, []), now)
        self._requests[client_ip] = recent

        if len(recent) >= self.requests_per_minute:
            logger.warning(f"Rate limit exceeded for {client_ip}")
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
                headers={"Retry-After": "60"},
            )

        # Rejected requests are not recorded; only served ones count.
        recent.append(now)
        return await call_next(request)
21 changes: 18 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from app.core.config import settings
from app.api.endpoints import auth, knowledge_bases, conversations, messages, users
from app.core.middleware import PermissionsMiddleware, DEFAULT_PATH_PERMISSIONS
from app.core.middleware import PermissionsMiddleware, RateLimitMiddleware, DEFAULT_PATH_PERMISSIONS

app = FastAPI(
title=settings.APP_NAME,
Expand All @@ -13,12 +13,18 @@
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173", "http://127.0.0.1:5173", "*"], # Explicitly allow frontend origin
allow_origins=settings.CORS_ORIGIN_LIST,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Add rate limiting middleware
app.add_middleware(
RateLimitMiddleware,
requests_per_minute=settings.RATE_LIMIT_PER_MINUTE,
)

# Add Permissions middleware
app.add_middleware(
PermissionsMiddleware,
Expand All @@ -34,4 +40,13 @@

@app.get("/")
async def root():
return {"message": "Welcome to DocBrain API"}
return {"message": "Welcome to DocBrain API"}

@app.get("/health")
async def health():
    """Report service status, name and version for monitoring and orchestration probes."""
    service_name = settings.APP_NAME
    release = app.version
    return {"status": "healthy", "service": service_name, "version": release}
3 changes: 2 additions & 1 deletion app/repositories/storage_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ async def insert_csv(db: Session, table_name: str, create_table_query: str, colu

# insert the data one by one
for row in data:
INSERT_ROW_QUERY = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join([f"'{str(cell)}'" for cell in row])})"
values = ', '.join(["'{}'".format(str(cell)) for cell in row])
INSERT_ROW_QUERY = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({values})"
logger.info(f"Insert Row Query: {INSERT_ROW_QUERY}")
db.execute(text(INSERT_ROW_QUERY))
db.commit()
Expand Down
3 changes: 2 additions & 1 deletion app/services/rag/chunker/chunker_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def create_chunker(document_type: DocumentType) -> Chunker:
Chunker instance
"""
try:
# TODO: Implement chunker factory based on document type
# MultiLevelChunker works well across all document types.
# Extend here with type-specific chunkers if needed (e.g., CSV row-based).
return MultiLevelChunker()
except Exception as e:
logger.error(f"Failed to create chunker: {e}", exc_info=True)
Expand Down
2 changes: 1 addition & 1 deletion app/services/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ async def retrieve_from_storage(
# Create retriever using the provided knowledge_base_id
retriever = RetrieverFactory.create_retriever(knowledge_base_id)

# TODO: Add Text 2 SQL to convert query to SQL and remove chunks
# For SQL-based retrieval, use the TAG service via QueryRouter instead.
# Retrieve chunks
chunks = await retriever.search(query, top_k, similarity_threshold)
return chunks
Expand Down
29 changes: 29 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Minimal dependencies for running the test suite in CI.
# The full requirements.txt includes heavy ML libraries (torch, transformers,
# sentence-transformers) that are not needed for unit tests.

# Core framework
fastapi>=0.115.0
uvicorn>=0.34.0
starlette>=0.45.0

# Database
SQLAlchemy>=2.0.25
PyMySQL>=1.1.0

# Auth & security
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4

# Validation
pydantic>=2.5.3
pydantic-settings>=2.1.0
email-validator>=2.0.0

# Config
python-dotenv>=1.0.0

# Testing
pytest>=8.0.0
pytest-asyncio>=0.25.0
pytest-cov>=6.0.0
Empty file added tests/__init__.py
Empty file.
Loading
Loading