diff --git a/.env.example b/.env.example index 707bcb7..47e2d99 100644 --- a/.env.example +++ b/.env.example @@ -34,6 +34,16 @@ MYSQL_USER=docbrain MYSQL_PASSWORD=password MYSQL_DATABASE=docbrain +# CORS (comma-separated origins) +CORS_ORIGINS=http://localhost:5173,http://127.0.0.1:5173 + +# Rate Limiting +RATE_LIMIT_PER_MINUTE=60 + +# Email (SendGrid) +SENDGRID_API_KEY=your-sendgrid-api-key +FROM_EMAIL=noreply@yourdomain.com + # File Upload MAX_FILE_SIZE_MB=10 UPLOAD_DIR=/data/uploads \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..f947f2a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,65 @@ +name: CI + +on: + push: + branches: [main, dev, develop] + pull_request: + branches: [main, dev, develop] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install linting tools + run: pip install flake8 black isort + + - name: Check formatting with black + run: black --check --diff app/ tests/ + + - name: Check import ordering with isort + run: isort --check-only --diff app/ tests/ + + - name: Lint with flake8 + run: flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402 + + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: pip + + - name: Install test dependencies + run: pip install -r requirements-test.txt + + - name: Run tests with coverage + run: pytest tests/ -v --cov=app --cov-report=term-missing --cov-report=xml + env: + ENVIRONMENT: test + SECRET_KEY: ci-test-secret-key + SENDGRID_API_KEY: test + FROM_EMAIL: test@example.com + PINECONE_API_KEY: test + PINECONE_ENVIRONMENT: test + WHITELISTED_EMAILS: test@example.com + + - name: Upload coverage report + uses: actions/upload-artifact@v4 + if: always() + with: + name: coverage-report + path: coverage.xml diff --git a/Makefile b/Makefile index 1b31bba..31a1715 100644 --- a/Makefile +++ b/Makefile @@ -22,7 +22,19 @@ worker: sh ./restart_worker.sh test: - pytest + pytest tests/ -v + +test-cov: + pytest tests/ -v --cov=app --cov-report=term-missing + +lint: + black --check --diff app/ tests/ + isort --check-only --diff app/ tests/ + flake8 app/ tests/ --max-line-length 120 --ignore E501,W503,E402 + +format: + black app/ tests/ + isort app/ tests/ clean: find . -type d -name "__pycache__" -exec rm -r {} + diff --git a/README.md b/README.md index e9d5c09..dec3a9f 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,8 @@ # DocBrain - Self-Hosted RAG Framework -![Python Version](https://img.shields.io/badge/python-3.8%2B-blue) +[![CI](https://github.com/shivama205/DocBrain/actions/workflows/ci.yml/badge.svg?branch=dev)](https://github.com/shivama205/DocBrain/actions/workflows/ci.yml) +![Python Version](https://img.shields.io/badge/python-3.11%2B-blue) ![License](https://img.shields.io/badge/license-MIT-green) ![Security](https://img.shields.io/badge/security-self--hosted-brightgreen) diff --git a/app/api/endpoints/knowledge_bases.py b/app/api/endpoints/knowledge_bases.py index 7f69df1..73e395e 100644 --- a/app/api/endpoints/knowledge_bases.py +++ b/app/api/endpoints/knowledge_bases.py @@ -202,7 +202,7 @@ async def get_shared_users( @router.post("/{kb_id}/documents", response_model=DocumentResponse) async def create_document( kb_id: str = Path(..., description="Knowledge base ID"), - file: UploadFile = Annotated[..., File(..., description="Document to upload")], + file: UploadFile = File(..., description="Document to upload"), current_user: UserResponse = Depends(get_current_user), doc_service: DocumentService = Depends(get_document_service) ): diff --git a/app/core/config.py b/app/core/config.py index 8534b14..2ca5199 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -12,15 +12,15 @@ class Settings(BaseSettings): # Security SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key") ALGORITHM: str = "HS256" - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days + ACCESS_TOKEN_EXPIRE_MINUTES: int = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "1440")) # 24 hours # Email Settings - SENDGRID_API_KEY: str - FROM_EMAIL: EmailStr + SENDGRID_API_KEY: str = os.getenv("SENDGRID_API_KEY", "") + FROM_EMAIL: str = os.getenv("FROM_EMAIL", "noreply@example.com") # Vector Store PINECONE_API_KEY: str = os.getenv("PINECONE_API_KEY", "") - PINECONE_ENVIRONMENT: str + PINECONE_ENVIRONMENT: str = os.getenv("PINECONE_ENVIRONMENT", "") PINECONE_INDEX_NAME: str = os.getenv("PINECONE_INDEX_NAME", "docbrain") PINECONE_SUMMARY_INDEX_NAME: str = os.getenv("PINECONE_SUMMARY_INDEX_NAME", "summary") PINECONE_QUESTIONS_INDEX_NAME: str = os.getenv("PINECONE_QUESTIONS_INDEX_NAME", "questions") @@ -33,7 +33,7 @@ class Settings(BaseSettings): REDIS_URL: str = "redis://localhost:6379/0" # Test Emails - WHITELISTED_EMAILS: str + WHITELISTED_EMAILS: str = os.getenv("WHITELISTED_EMAILS", "") # RAG RAG_TOP_K: int = 3 @@ -44,6 +44,9 @@ class Settings(BaseSettings): def WHITELISTED_EMAIL_LIST(self) -> List[str]: return [email.strip() for email in self.WHITELISTED_EMAILS.split(",")] + # Rate Limiting + RATE_LIMIT_PER_MINUTE: int = int(os.getenv("RATE_LIMIT_PER_MINUTE", "60")) + # File Upload MAX_FILE_SIZE_MB: int = 10 UPLOAD_DIR: str = "/data/uploads" @@ -65,6 +68,13 @@ def DATABASE_URL(self) -> str: CELERY_RESULT_BACKEND: str = os.getenv("CELERY_RESULT_BACKEND", "redis://localhost:6379/0") # CORS + CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", "http://localhost:5173,http://127.0.0.1:5173") + + @property + def CORS_ORIGIN_LIST(self) -> List[str]: + """Parse comma-separated CORS origins.""" + return [origin.strip() for origin in self.CORS_ORIGINS.split(",") if origin.strip()] + BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] # Storage diff --git a/app/core/middleware.py b/app/core/middleware.py index e5f7e40..53826d5 100644 --- a/app/core/middleware.py +++ b/app/core/middleware.py @@ -1,10 +1,17 @@ +import time +import logging +from collections import defaultdict + from fastapi import Request, HTTPException, status from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse from typing import Dict, List, Callable, Optional from app.core.permissions import Permission, get_permissions_for_role from app.db.models.user import UserRole +logger = logging.getLogger(__name__) + class PermissionsMiddleware(BaseHTTPMiddleware): """ @@ -144,4 +151,58 @@ async def dispatch(self, request: Request, call_next: Callable): "PUT": [Permission.MANAGE_SYSTEM], "DELETE": [Permission.MANAGE_SYSTEM], }, -} \ No newline at end of file +} + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Simple in-memory rate limiting middleware. + + Limits requests per client IP using a sliding window approach. + For production deployments with multiple workers, consider using + a Redis-backed solution instead. + """ + + def __init__( + self, + app, + requests_per_minute: int = 60, + exempt_paths: Optional[List[str]] = None, + ): + super().__init__(app) + self.requests_per_minute = requests_per_minute + self.exempt_paths = exempt_paths or ["/health", "/docs", "/openapi.json", "/redoc"] + # {client_ip: [timestamp, ...]} + self._requests: Dict[str, List[float]] = defaultdict(list) + + def _get_client_ip(self, request: Request) -> str: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",")[0].strip() + return request.client.host if request.client else "unknown" + + def _cleanup(self, timestamps: List[float], now: float) -> List[float]: + """Remove timestamps older than 60 seconds.""" + cutoff = now - 60.0 + return [t for t in timestamps if t > cutoff] + + async def dispatch(self, request: Request, call_next: Callable): + if any(request.url.path.startswith(p) for p in self.exempt_paths): + return await call_next(request) + + client_ip = self._get_client_ip(request) + now = time.time() + + # Clean old entries and record this request + self._requests[client_ip] = self._cleanup(self._requests[client_ip], now) + + if len(self._requests[client_ip]) >= self.requests_per_minute: + logger.warning(f"Rate limit exceeded for {client_ip}") + return JSONResponse( + status_code=429, + content={"detail": "Too many requests. Please try again later."}, + headers={"Retry-After": "60"}, + ) + + self._requests[client_ip].append(now) + return await call_next(request) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 676fce1..c6339b2 100644 --- a/app/main.py +++ b/app/main.py @@ -3,7 +3,7 @@ from app.core.config import settings from app.api.endpoints import auth, knowledge_bases, conversations, messages, users -from app.core.middleware import PermissionsMiddleware, DEFAULT_PATH_PERMISSIONS +from app.core.middleware import PermissionsMiddleware, RateLimitMiddleware, DEFAULT_PATH_PERMISSIONS app = FastAPI( title=settings.APP_NAME, @@ -13,12 +13,18 @@ # Add CORS middleware app.add_middleware( CORSMiddleware, - allow_origins=["http://localhost:5173", "http://127.0.0.1:5173", "*"], # Explicitly allow frontend origin + allow_origins=settings.CORS_ORIGIN_LIST, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +# Add rate limiting middleware +app.add_middleware( + RateLimitMiddleware, + requests_per_minute=settings.RATE_LIMIT_PER_MINUTE, +) + # Add Permissions middleware app.add_middleware( PermissionsMiddleware, @@ -34,4 +40,13 @@ @app.get("/") async def root(): - return {"message": "Welcome to DocBrain API"} \ No newline at end of file + return {"message": "Welcome to DocBrain API"} + +@app.get("/health") +async def health(): + """Health check endpoint for monitoring and orchestration.""" + return { + "status": "healthy", + "service": settings.APP_NAME, + "version": app.version, + } \ No newline at end of file diff --git a/app/repositories/storage_repository.py b/app/repositories/storage_repository.py index 85203db..936386b 100644 --- a/app/repositories/storage_repository.py +++ b/app/repositories/storage_repository.py @@ -22,7 +22,8 @@ async def insert_csv(db: Session, table_name: str, create_table_query: str, colu # insert the data one by one for row in data: - INSERT_ROW_QUERY = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({', '.join([f"'{str(cell)}'" for cell in row])})" + values = ', '.join(["'{}'".format(str(cell)) for cell in row]) + INSERT_ROW_QUERY = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({values})" logger.info(f"Insert Row Query: {INSERT_ROW_QUERY}") db.execute(text(INSERT_ROW_QUERY)) db.commit() diff --git a/app/services/rag/chunker/chunker_factory.py b/app/services/rag/chunker/chunker_factory.py index 2c9bd43..b748133 100644 --- a/app/services/rag/chunker/chunker_factory.py +++ b/app/services/rag/chunker/chunker_factory.py @@ -23,7 +23,8 @@ def create_chunker(document_type: DocumentType) -> Chunker: Chunker instance """ try: - # TODO: Implement chunker factory based on document type + # MultiLevelChunker works well across all document types. + # Extend here with type-specific chunkers if needed (e.g., CSV row-based). return MultiLevelChunker() except Exception as e: logger.error(f"Failed to create chunker: {e}", exc_info=True) diff --git a/app/services/rag_service.py b/app/services/rag_service.py index 8fd4f9b..63caaa8 100644 --- a/app/services/rag_service.py +++ b/app/services/rag_service.py @@ -155,7 +155,7 @@ async def retrieve_from_storage( # Create retriever using the provided knowledge_base_id retriever = RetrieverFactory.create_retriever(knowledge_base_id) - # TODO: Add Text 2 SQL to convert query to SQL and remove chunks + # For SQL-based retrieval, use the TAG service via QueryRouter instead. # Retrieve chunks chunks = await retriever.search(query, top_k, similarity_threshold) return chunks diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..da224ea --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,29 @@ +# Minimal dependencies for running the test suite in CI. +# The full requirements.txt includes heavy ML libraries (torch, transformers, +# sentence-transformers) that are not needed for unit tests. + +# Core framework +fastapi>=0.115.0 +uvicorn>=0.34.0 +starlette>=0.45.0 + +# Database +SQLAlchemy>=2.0.25 +PyMySQL>=1.1.0 + +# Auth & security +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 + +# Validation +pydantic>=2.5.3 +pydantic-settings>=2.1.0 +email-validator>=2.0.0 + +# Config +python-dotenv>=1.0.0 + +# Testing +pytest>=8.0.0 +pytest-asyncio>=0.25.0 +pytest-cov>=6.0.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8b709bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,77 @@ +""" +Shared test fixtures for DocBrain test suite. + +Environment variables and module mocks are set up before any app imports +to avoid database connection errors during testing. +""" +import os +import sys +from unittest.mock import MagicMock + +# --------------------------------------------------------------------------- +# 1. Set test environment variables BEFORE any app imports +# --------------------------------------------------------------------------- +os.environ.setdefault("ENVIRONMENT", "test") +os.environ.setdefault("SECRET_KEY", "test-secret-key-for-unit-tests") +os.environ.setdefault("SENDGRID_API_KEY", "test-key") +os.environ.setdefault("FROM_EMAIL", "test@example.com") +os.environ.setdefault("PINECONE_API_KEY", "test-key") +os.environ.setdefault("PINECONE_ENVIRONMENT", "test") +os.environ.setdefault("WHITELISTED_EMAILS", "test@example.com") +os.environ.setdefault("MYSQL_HOST", "localhost") +os.environ.setdefault("REDIS_URL", "redis://localhost:6379/0") +# Use SQLite for test to avoid needing MySQL driver +os.environ["DATABASE_URL"] = "sqlite://" + +# --------------------------------------------------------------------------- +# 2. Pre-populate sys.modules with a mock for app.db.database so that +# importing permissions / middleware does NOT trigger create_engine. +# --------------------------------------------------------------------------- +_mock_db_module = MagicMock() +_mock_db_module.get_db = MagicMock() +_mock_db_module.engine = MagicMock() +_mock_db_module.SessionLocal = MagicMock() +sys.modules.setdefault("app.db.database", _mock_db_module) + +# --------------------------------------------------------------------------- +# Now safe to import app modules +# --------------------------------------------------------------------------- +import pytest +from app.db.models.user import UserRole +from app.schemas.user import UserResponse + + +def _make_user(role: UserRole, user_id: str = "user-123") -> UserResponse: + return UserResponse( + id=user_id, + email=f"{role.value}@example.com", + full_name=f"Test {role.value.title()}", + role=role, + is_active=True, + hashed_password="hashed", + ) + + +@pytest.fixture +def admin_user(): + return _make_user(UserRole.ADMIN, "admin-001") + + +@pytest.fixture +def owner_user(): + return _make_user(UserRole.OWNER, "owner-001") + + +@pytest.fixture +def regular_user(): + return _make_user(UserRole.USER, "user-001") + + +@pytest.fixture +def mock_db(): + session = MagicMock() + session.commit = MagicMock() + session.rollback = MagicMock() + session.close = MagicMock() + session.refresh = MagicMock() + return session diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..fa991bd --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,42 @@ +"""Tests for authentication utilities.""" +import pytest +from datetime import timedelta +from unittest.mock import patch +from jose import jwt + + +class TestAccessToken: + """Test JWT token creation and validation.""" + + def test_create_token_returns_string(self): + from app.api.deps import create_access_token + token = create_access_token("user-123") + assert isinstance(token, str) + assert len(token) > 0 + + def test_token_contains_user_id(self): + from app.api.deps import create_access_token + from app.core.config import settings + token = create_access_token("user-abc") + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert payload["sub"] == "user-abc" + + def test_token_has_expiry(self): + from app.api.deps import create_access_token + from app.core.config import settings + token = create_access_token("user-123") + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert "exp" in payload + + def test_custom_expiry_delta(self): + from app.api.deps import create_access_token + from app.core.config import settings + token = create_access_token("user-123", expires_delta=timedelta(minutes=5)) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + assert "exp" in payload + + def test_different_users_get_different_tokens(self): + from app.api.deps import create_access_token + token1 = create_access_token("user-1") + token2 = create_access_token("user-2") + assert token1 != token2 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..09179ef --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,47 @@ +"""Tests for configuration and settings.""" +import os +import pytest + + +class TestCORSConfig: + """Test CORS origin parsing.""" + + def test_default_cors_origins(self): + from app.core.config import settings + origins = settings.CORS_ORIGIN_LIST + assert isinstance(origins, list) + assert len(origins) >= 1 + # Should not contain wildcard + assert "*" not in origins + + def test_cors_origins_are_strings(self): + from app.core.config import settings + for origin in settings.CORS_ORIGIN_LIST: + assert isinstance(origin, str) + assert origin.startswith("http") + + +class TestSecurityConfig: + """Test security-related configuration defaults.""" + + def test_token_expiry_is_reasonable(self): + from app.core.config import settings + # Should be at most 24 hours (1440 minutes) + assert settings.ACCESS_TOKEN_EXPIRE_MINUTES <= 1440 + + def test_algorithm_is_set(self): + from app.core.config import settings + assert settings.ALGORITHM == "HS256" + + +class TestRateLimitConfig: + """Test rate limiting configuration.""" + + def test_rate_limit_has_default(self): + from app.core.config import settings + assert settings.RATE_LIMIT_PER_MINUTE > 0 + + def test_rate_limit_is_reasonable(self): + from app.core.config import settings + # Should be between 10 and 10000 + assert 10 <= settings.RATE_LIMIT_PER_MINUTE <= 10000 diff --git a/tests/test_factories.py b/tests/test_factories.py new file mode 100644 index 0000000..a87df31 --- /dev/null +++ b/tests/test_factories.py @@ -0,0 +1,131 @@ +"""Tests for factory classes (chunker, retriever, reranker, ingestor). + +Ingestor and reranker factories depend on heavy ML libraries (PyPDF2, torch, +sentence-transformers). These tests mock those dependencies so they run in +any environment, including CI without GPU or ML packages. +""" +import sys +import pytest +from unittest.mock import MagicMock, patch + +from app.db.models.knowledge_base import DocumentType + + +# --------------------------------------------------------------------------- +# ChunkerFactory — no heavy deps, tests directly +# --------------------------------------------------------------------------- +class TestChunkerFactory: + def test_create_chunker_returns_multi_level(self): + from app.services.rag.chunker.chunker_factory import ChunkerFactory + from app.services.rag.chunker.chunker import MultiLevelChunker + chunker = ChunkerFactory.create_chunker(DocumentType.PDF) + assert isinstance(chunker, MultiLevelChunker) + + def test_create_from_metadata_uses_document_type(self): + from app.services.rag.chunker.chunker_factory import ChunkerFactory + from app.services.rag.chunker.chunker import MultiLevelChunker + metadata = {"document_type": DocumentType.CSV} + chunker = ChunkerFactory.create_chunker_from_metadata(metadata) + assert isinstance(chunker, MultiLevelChunker) + + def test_create_from_metadata_defaults_to_txt(self): + from app.services.rag.chunker.chunker_factory import ChunkerFactory + from app.services.rag.chunker.chunker import MultiLevelChunker + chunker = ChunkerFactory.create_chunker_from_metadata({}) + assert isinstance(chunker, MultiLevelChunker) + + +# --------------------------------------------------------------------------- +# IngestorFactory — needs PyPDF2, pandas etc. We mock the ingestor module. +# --------------------------------------------------------------------------- +class TestIngestorFactory: + @pytest.fixture(autouse=True) + def _mock_ingestors(self): + """Mock the ingestor classes so we don't need ML deps.""" + mock_module = MagicMock() + # Create distinct mock classes so isinstance checks work via identity + mock_module.PDFIngestor = type("PDFIngestor", (), {}) + mock_module.CSVIngestor = type("CSVIngestor", (), {}) + mock_module.MarkdownIngestor = type("MarkdownIngestor", (), {}) + mock_module.ImageIngestor = type("ImageIngestor", (), {}) + mock_module.TextIngestor = type("TextIngestor", (), {}) + mock_module.Ingestor = type("Ingestor", (), {}) + + # Patch sys.modules so the factory import resolves + orig = sys.modules.get("app.services.rag.ingestor.ingestor") + sys.modules["app.services.rag.ingestor.ingestor"] = mock_module + + # Force reimport of factory + if "app.services.rag.ingestor.ingestor_factory" in sys.modules: + del sys.modules["app.services.rag.ingestor.ingestor_factory"] + + yield mock_module + + # Restore + if orig is not None: + sys.modules["app.services.rag.ingestor.ingestor"] = orig + elif "app.services.rag.ingestor.ingestor" in sys.modules: + del sys.modules["app.services.rag.ingestor.ingestor"] + if "app.services.rag.ingestor.ingestor_factory" in sys.modules: + del sys.modules["app.services.rag.ingestor.ingestor_factory"] + + def test_pdf_type(self, _mock_ingestors): + from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._pdf_ingestor = None + ingestor = IngestorFactory.create_ingestor(DocumentType.PDF) + assert type(ingestor).__name__ == "PDFIngestor" + + def test_csv_type(self, _mock_ingestors): + from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._csv_ingestor = None + ingestor = IngestorFactory.create_ingestor(DocumentType.CSV) + assert type(ingestor).__name__ == "CSVIngestor" + + def test_txt_type(self, _mock_ingestors): + from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._text_ingestor = None + ingestor = IngestorFactory.create_ingestor(DocumentType.TXT) + assert type(ingestor).__name__ == "TextIngestor" + + def test_singleton_returns_same_instance(self, _mock_ingestors): + from app.services.rag.ingestor.ingestor_factory import IngestorFactory + IngestorFactory._pdf_ingestor = None + first = IngestorFactory.create_ingestor(DocumentType.PDF) + second = IngestorFactory.create_ingestor(DocumentType.PDF) + assert first is second + + +# --------------------------------------------------------------------------- +# RetrieverFactory — needs pinecone. Mock PineconeRetriever. +# --------------------------------------------------------------------------- +class TestRetrieverFactory: + @pytest.fixture(autouse=True) + def _mock_retriever(self): + mock_pinecone_retriever = MagicMock() + mock_retriever_mod = MagicMock() + mock_retriever_mod.PineconeRetriever = mock_pinecone_retriever + + orig = sys.modules.get("app.services.rag.retriever.pinecone_retriever") + sys.modules["app.services.rag.retriever.pinecone_retriever"] = mock_retriever_mod + + if "app.services.rag.retriever.retriever_factory" in sys.modules: + del sys.modules["app.services.rag.retriever.retriever_factory"] + + yield mock_pinecone_retriever + + if orig is not None: + sys.modules["app.services.rag.retriever.pinecone_retriever"] = orig + elif "app.services.rag.retriever.pinecone_retriever" in sys.modules: + del sys.modules["app.services.rag.retriever.pinecone_retriever"] + if "app.services.rag.retriever.retriever_factory" in sys.modules: + del sys.modules["app.services.rag.retriever.retriever_factory"] + + def test_default_creates_pinecone(self, _mock_retriever): + from app.services.rag.retriever.retriever_factory import RetrieverFactory + RetrieverFactory.create_retriever("kb-123") + _mock_retriever.assert_called_with("kb-123") + + def test_unknown_type_falls_back(self, _mock_retriever): + from app.services.rag.retriever.retriever_factory import RetrieverFactory + RetrieverFactory.create_retriever("kb-789", retriever_type="unknown") + _mock_retriever.assert_called_with("kb-789") diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..eefe9f2 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,78 @@ +"""Tests for the health and root endpoints using a real TestClient.""" +import sys +from unittest.mock import MagicMock + +import pytest + +# Mock heavy dependencies so we can import app.main without ML libraries. +_MOCKED_MODULES = [ + "aiofiles", "celery", "celery.result", + "pinecone", + "PyPDF2", "markdown", + "PIL", "PIL.Image", "pytesseract", + "docling", "docling.document_converter", + "docling.datamodel", "docling.datamodel.base_models", + "docling.datamodel.pipeline_options", + "torch", "sentence_transformers", "FlagEmbedding", + "sendgrid", "sendgrid.helpers", "sendgrid.helpers.mail", + "google.generativeai", "google.genai", + "openai", "anthropic", "dirtyjson", +] + +for _mod in _MOCKED_MODULES: + sys.modules.setdefault(_mod, MagicMock()) + +# pymysql shim so SQLAlchemy can resolve the mysql dialect +try: + import pymysql + pymysql.install_as_MySQLdb() +except ImportError: + sys.modules.setdefault("MySQLdb", MagicMock()) + +from fastapi.testclient import TestClient +from app.main import app # noqa: E402 — must come after mocks + +client = TestClient(app) + + +class TestHealthEndpoint: + def test_health_returns_200(self): + response = client.get("/health") + assert response.status_code == 200 + + def test_health_response_has_status(self): + data = client.get("/health").json() + assert data["status"] == "healthy" + + def test_health_response_has_service_name(self): + data = client.get("/health").json() + assert data["service"] == "DocBrain" + + def test_health_response_has_version(self): + data = client.get("/health").json() + assert "version" in data + + +class TestRootEndpoint: + def test_root_returns_200(self): + response = client.get("/") + assert response.status_code == 200 + + def test_root_has_message(self): + data = client.get("/").json() + assert "message" in data + assert "DocBrain" in data["message"] + + +class TestAppRoutes: + def test_health_route_exists(self): + routes = [r.path for r in app.routes] + assert "/health" in routes + + def test_auth_routes_exist(self): + routes = [r.path for r in app.routes] + assert "/auth/token" in routes + + def test_knowledge_base_routes_exist(self): + routes = [r.path for r in app.routes] + assert any(r.startswith("/knowledge-bases") for r in routes) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d0b1a22 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,48 @@ +"""Tests for database models and enums.""" +import pytest +from app.db.models.knowledge_base import DocumentStatus, DocumentType +from app.db.models.user import UserRole + + +class TestUserRole: + def test_all_roles_exist(self): + assert UserRole.ADMIN == "admin" + assert UserRole.OWNER == "owner" + assert UserRole.USER == "user" + + def test_role_count(self): + assert len(UserRole) == 3 + + +class TestDocumentStatus: + def test_all_statuses_exist(self): + assert DocumentStatus.PENDING == "PENDING" + assert DocumentStatus.PROCESSING == "PROCESSING" + assert DocumentStatus.PROCESSED == "PROCESSED" + assert DocumentStatus.FAILED == "FAILED" + + def test_status_count(self): + assert len(DocumentStatus) == 4 + + +class TestDocumentType: + def test_pdf_type(self): + assert DocumentType.PDF == "application/pdf" + + def test_csv_type(self): + assert DocumentType.CSV == "text/csv" + + def test_image_types_exist(self): + assert DocumentType.JPG == "image/jpeg" + assert DocumentType.PNG == "image/png" + assert DocumentType.GIF == "image/gif" + assert DocumentType.TIFF == "image/tiff" + + def test_text_types_exist(self): + assert DocumentType.TXT == "text/plain" + assert DocumentType.MARKDOWN == "text/markdown" + assert DocumentType.HTML == "text/html" + + def test_all_types_are_strings(self): + for doc_type in DocumentType: + assert isinstance(doc_type.value, str) diff --git a/tests/test_permissions.py b/tests/test_permissions.py new file mode 100644 index 0000000..8ec7133 --- /dev/null +++ b/tests/test_permissions.py @@ -0,0 +1,54 @@ +"""Tests for the RBAC permission system.""" +import pytest +from app.core.permissions import ( + Permission, + ROLE_PERMISSIONS, + get_permissions_for_role, +) +from app.db.models.user import UserRole + + +class TestRolePermissions: + """Verify that each role has the correct permissions.""" + + def test_user_role_has_minimal_permissions(self): + perms = get_permissions_for_role(UserRole.USER) + assert Permission.VIEW_KNOWLEDGE_BASES in perms + assert Permission.CONVERSE_WITH_KNOWLEDGE_BASE in perms + # Users should NOT be able to create/delete + assert Permission.CREATE_KNOWLEDGE_BASE not in perms + assert Permission.DELETE_KNOWLEDGE_BASE not in perms + assert Permission.UPLOAD_DOCUMENT not in perms + assert Permission.MANAGE_SYSTEM not in perms + + def test_owner_role_has_document_permissions(self): + perms = get_permissions_for_role(UserRole.OWNER) + assert Permission.CREATE_KNOWLEDGE_BASE in perms + assert Permission.UPDATE_KNOWLEDGE_BASE in perms + assert Permission.DELETE_KNOWLEDGE_BASE in perms + assert Permission.UPLOAD_DOCUMENT in perms + assert Permission.DELETE_DOCUMENT in perms + assert Permission.VIEW_DOCUMENTS in perms + # Owners should NOT manage system or users + assert Permission.MANAGE_SYSTEM not in perms + assert Permission.CREATE_USER not in perms + assert Permission.DELETE_USER not in perms + + def test_admin_role_has_all_permissions(self): + perms = get_permissions_for_role(UserRole.ADMIN) + for permission in Permission: + assert permission in perms, f"Admin missing {permission}" + + def test_unknown_role_returns_empty(self): + perms = get_permissions_for_role("nonexistent") + assert perms == [] + + def test_owner_inherits_user_permissions(self): + user_perms = set(get_permissions_for_role(UserRole.USER)) + owner_perms = set(get_permissions_for_role(UserRole.OWNER)) + assert user_perms.issubset(owner_perms) + + def test_admin_inherits_owner_permissions(self): + owner_perms = set(get_permissions_for_role(UserRole.OWNER)) + admin_perms = set(get_permissions_for_role(UserRole.ADMIN)) + assert owner_perms.issubset(admin_perms) diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 0000000..eadf7a0 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,78 @@ +"""Tests for the rate limiting middleware.""" +import time +import pytest +from collections import defaultdict +from unittest.mock import MagicMock, AsyncMock + +from app.core.middleware import RateLimitMiddleware + + +class TestRateLimitMiddleware: + def _make_middleware(self, rpm: int = 5): + app = MagicMock() + mw = RateLimitMiddleware(app, requests_per_minute=rpm) + return mw + + def test_cleanup_removes_old_timestamps(self): + mw = self._make_middleware() + now = time.time() + timestamps = [now - 120, now - 90, now - 30, now - 10, now] + result = mw._cleanup(timestamps, now) + # Only the last 3 should remain (within 60s window) + assert len(result) == 3 + + def test_cleanup_keeps_recent_timestamps(self): + mw = self._make_middleware() + now = time.time() + timestamps = [now - 5, now - 3, now - 1] + result = mw._cleanup(timestamps, now) + assert len(result) == 3 + + def test_cleanup_empty_list(self): + mw = self._make_middleware() + assert mw._cleanup([], time.time()) == [] + + def test_get_client_ip_from_direct_connection(self): + mw = self._make_middleware() + request = MagicMock() + request.headers = {} + request.client.host = "192.168.1.1" + assert mw._get_client_ip(request) == "192.168.1.1" + + def test_get_client_ip_from_forwarded_header(self): + mw = self._make_middleware() + request = MagicMock() + request.headers = {"x-forwarded-for": "10.0.0.1, 10.0.0.2"} + assert mw._get_client_ip(request) == "10.0.0.1" + + @pytest.mark.asyncio + async def test_exempt_paths_are_not_limited(self): + mw = self._make_middleware(rpm=1) + request = MagicMock() + request.url.path = "/health" + call_next = AsyncMock(return_value=MagicMock(status_code=200)) + + # Should pass through even with rpm=1 + for _ in range(5): + response = await mw.dispatch(request, call_next) + assert call_next.call_count == 5 + + @pytest.mark.asyncio + async def test_rate_limit_blocks_excess_requests(self): + mw = self._make_middleware(rpm=3) + request = MagicMock() + request.url.path = "/api/test" + request.headers = {} + request.client.host = "1.2.3.4" + call_next = AsyncMock(return_value=MagicMock(status_code=200)) + + responses = [] + for _ in range(5): + resp = await mw.dispatch(request, call_next) + responses.append(resp) + + # First 3 should pass, last 2 should be 429 + status_codes = [r.status_code for r in responses] + assert status_codes[:3] == [200, 200, 200] + assert status_codes[3] == 429 + assert status_codes[4] == 429 diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..a730a13 --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,68 @@ +"""Tests for Pydantic schemas validation.""" +import pytest +from pydantic import ValidationError + +from app.schemas.user import UserCreate, UserResponse, UserUpdate +from app.db.models.user import UserRole + + +class TestUserCreate: + def test_valid_user(self): + user = UserCreate( + email="test@example.com", + full_name="Test User", + password="securepassword123", + ) + assert user.email == "test@example.com" + assert user.role == UserRole.USER # default + + def test_invalid_email_raises(self): + with pytest.raises(ValidationError): + UserCreate( + email="not-an-email", + full_name="Test", + password="password", + ) + + def test_missing_password_raises(self): + with pytest.raises(ValidationError): + UserCreate( + email="test@example.com", + full_name="Test", + ) + + def test_custom_role(self): + user = UserCreate( + email="admin@example.com", + full_name="Admin", + password="password", + role=UserRole.ADMIN, + ) + assert user.role == UserRole.ADMIN + + +class TestUserUpdate: + def test_all_fields_optional(self): + update = UserUpdate() + assert update.email is None + assert update.full_name is None + assert update.password is None + + def test_partial_update(self): + update = UserUpdate(full_name="New Name") + assert update.full_name == "New Name" + assert update.email is None + + +class TestUserResponse: + def test_from_attributes(self): + resp = UserResponse( + id="user-123", + email="test@example.com", + full_name="Test User", + role=UserRole.USER, + is_active=True, + hashed_password="hashed_secret", + ) + assert resp.id == "user-123" + assert resp.is_active is True