From b16a73e494b24608870d808ce72f6ee06081c421 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 30 Oct 2025 16:56:50 -0400 Subject: [PATCH 1/3] fix project structure. add toml scan support --- .github/workflows/securefix.yml | 6 +- .github/workflows/test.yml | 7 +- .gitignore | 2 +- README.md | 126 ++++-- bandit.yml | 2 +- cve/scanner.py | 31 -- pyproject.toml | 121 +++++ requirements-dev.txt | 2 - requirements.txt | Bin 914 -> 0 bytes {cve => securefix}/__init__.py | 0 securefix.py => securefix/cli.py | 21 +- .../corpus_downloader.py | 2 +- {remediation => securefix/cve}/__init__.py | 0 {cve => securefix/cve}/db.py | 2 +- securefix/cve/scanner.py | 154 +++++++ models.py => securefix/models.py | 0 {sast => securefix/remediation}/__init__.py | 0 .../remediation}/config.py | 0 .../remediation}/corpus_builder.py | 5 +- .../remediation}/fix_cache.py | 0 .../remediation}/fix_knowledge_store.py | 2 +- .../remediation}/llm_factory.py | 0 .../remediation}/markdown_processor.py | 0 .../remediation}/remediation_engine.py | 9 +- .../remediation}/vulnerability_retriever.py | 2 +- securefix/sast/__init__.py | 0 {sast => securefix/sast}/bandit_mapper.py | 2 +- {sast => securefix/sast}/bandit_scanner.py | 4 +- tests/conftest.py | 2 +- tests/test_bandit_scanner.py | 24 +- tests/test_corpus_builder.py | 17 +- tests/test_cve_scanner.py | 427 ++++++++++++------ tests/test_fix_cache.py | 5 +- tests/test_fix_knowledge_store.py | 14 +- tests/test_llm_factory.py | 9 +- tests/test_remediation_engine.py | 12 +- tests/test_securefix.py | 6 +- tests/test_vulnerability_retriever.py | 17 +- 38 files changed, 729 insertions(+), 304 deletions(-) delete mode 100644 cve/scanner.py create mode 100644 pyproject.toml delete mode 100644 requirements-dev.txt delete mode 100644 requirements.txt rename {cve => securefix}/__init__.py (100%) rename securefix.py => securefix/cli.py (96%) rename corpus_downloader.py => securefix/corpus_downloader.py (99%) rename {remediation => securefix/cve}/__init__.py 
(100%) rename {cve => securefix/cve}/db.py (94%) create mode 100644 securefix/cve/scanner.py rename models.py => securefix/models.py (100%) rename {sast => securefix/remediation}/__init__.py (100%) rename {remediation => securefix/remediation}/config.py (100%) rename {remediation => securefix/remediation}/corpus_builder.py (99%) rename {remediation => securefix/remediation}/fix_cache.py (100%) rename {remediation => securefix/remediation}/fix_knowledge_store.py (94%) rename {remediation => securefix/remediation}/llm_factory.py (100%) rename {remediation => securefix/remediation}/markdown_processor.py (100%) rename {remediation => securefix/remediation}/remediation_engine.py (97%) rename {remediation => securefix/remediation}/vulnerability_retriever.py (99%) create mode 100644 securefix/sast/__init__.py rename {sast => securefix/sast}/bandit_mapper.py (98%) rename {sast => securefix/sast}/bandit_scanner.py (95%) diff --git a/.github/workflows/securefix.yml b/.github/workflows/securefix.yml index 272eac4..bb18069 100644 --- a/.github/workflows/securefix.yml +++ b/.github/workflows/securefix.yml @@ -12,15 +12,13 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.11' - cache: 'pip' - name: Install SecureFix run: | - pip install -r requirements.txt - pip install bandit==1.7.5 + pip install -e . - name: Run scan - run: python securefix.py scan . --output results.json + run: securefix scan . 
--output results.json continue-on-error: true - name: Upload results diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f3e97a5..3500cee 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,15 +20,12 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - cache: 'pip' - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest pytest-cov - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + pip install -e ".[dev]" - name: Run tests with pytest run: | - pytest tests/ -v --cov=. --cov-report=term-missing \ No newline at end of file + pytest tests/ -v --cov=securefix --cov-report=term-missing \ No newline at end of file diff --git a/.gitignore b/.gitignore index f5baf1c..bab208e 100644 --- a/.gitignore +++ b/.gitignore @@ -64,7 +64,7 @@ chroma_db/ model_cache/ # Security corpus -remediation/corpus/ +securefix/remediation/corpus/ # LLM caches .ollama/ diff --git a/README.md b/README.md index 8848257..b89749c 100644 --- a/README.md +++ b/README.md @@ -40,16 +40,59 @@ SecureFix bridges rule-based precision and AI-driven guidance through two core c ## Installation +### From Source + ```bash +# Clone the repository git clone https://github.com/hakal/securefix.git cd securefix -python -m venv venv -source venv/bin/activate # Windows: venv\Scripts\activate -pip install -r requirements.txt + +# Install with pip (recommended) +pip install -e . 
+ +# Or install with development dependencies +pip install -e ".[dev]" + +# Or install with all optional dependencies +pip install -e ".[all]" +``` + +### Optional Dependencies + +```bash +# Install with LlamaCPP support (for local model inference) +pip install -e ".[llamacpp]" + +# Install development tools (pytest, coverage) +pip install -e ".[dev]" +``` + +## Configuration + +### Environment Variables + +Create a `.env` file in the project root: + +```bash +# For Google Gemini support +GOOGLE_API_KEY=your_api_key_here + +# Optional: Default model configuration +MODEL_NAME=llama3.2:3b ``` ### LLM Setup +**Ollama (Local - Default):** +- Install Ollama: https://ollama.com/ +- Pull a model: `ollama pull llama3.2:3b` +- No API key required + +**Google Gemini (Cloud):** +- Set `GOOGLE_API_KEY` in `.env` +- Use `--llm-mode google` flag +- Requires internet connection + **Model Recommendations** **For best results:** @@ -77,18 +120,19 @@ echo "GOOGLE_API_KEY=your_key_here" > .env ### Build Knowledge Base (One-time setup) First, ingest your security corpus to build the vector database: + ```bash -# Use this script, or source your own -python corpus_downloader.py --corpus-path ./remediation/corpus +# Download security corpus (use this script, or source your own) +python securefix/corpus_downloader.py --corpus-path ./remediation/corpus -# Use default corpus location (./remediation/corpus) -python securefix.py ingest +# Build vector database from corpus +securefix ingest # Or specify custom corpus path -python securefix.py ingest --corpus-path /path/to/corpus +securefix ingest --corpus-path /path/to/corpus # Rebuild existing database -python securefix.py ingest --rebuild +securefix ingest --rebuild ``` **Supported corpus formats:** @@ -100,41 +144,48 @@ python securefix.py ingest --rebuild ```bash # Scan a single file -python securefix.py scan path/to/code.py +securefix scan path/to/code.py # Scan a directory -python securefix.py scan src/ +securefix scan src/ # Scan with 
dependencies -python securefix.py scan src/ --dependencies requirements.txt +securefix scan src/ --dependencies requirements.txt # Custom output file -python securefix.py scan src/ -d requirements.txt -o my_report.json +securefix scan src/ -d requirements.txt -o my_report.json ``` ### Remediation ```bash # Generate fix suggestions -python securefix.py fix report.json --output fixes.json +securefix fix report.json --output fixes.json + +# Interactive mode (review and approve each fix) +securefix fix report.json --interactive -# Interactive mode -python securefix.py fix report.json --interactive +# Choose LLM backend +securefix fix report.json --llm-mode local # Ollama (default) +securefix fix report.json --llm-mode google # Google Gemini -# Local or cloud -python securefix.py fix report.json --llm-mode local|google +# Specify model name +securefix fix report.json --model-name llama3.2:3b -# Choose model -python securefix.py fix report.json --model-name qwen3:4b +# Disable semantic caching +securefix fix report.json --no-cache -# Disable cache -python securefix.py fix report.json --no-cache +# Custom vector database location +securefix fix report.json --persist-dir ./my_chroma_db -# Vector DB location -python securefix.py fix report.json --persist-dir /remediation/chroma_db +# Filter by severity (only fix high/critical vulnerabilities) +securefix fix report.json --severity-filter high -# Filter by severity -python securefix.py fix report.json --severity-filter +# Only remediate SAST findings (skip CVE findings) +securefix fix report.json --sast-only + +# Only remediate CVE findings (skip SAST findings) +securefix fix report.json --cve-only ``` ### Output Format @@ -267,6 +318,27 @@ python securefix.py fix report.json --severity-filter } ``` + +## Development + +### Running Tests + +```bash +# Install development dependencies +pip install -e ".[dev]" + +# Run all tests +pytest + +# Run with coverage +pytest --cov=securefix --cov-report=html + +# Run specific test 
categories +pytest -m unit # Unit tests only +pytest -m integration # Integration tests only +pytest -m "not slow" # Skip slow tests +``` + ## Technical Approach ### Detection Pipeline @@ -311,7 +383,7 @@ pytest --cov=securefix tests/ - ollama: Local LLM support - click: CLI framework -See `requirements.txt` && `requirements-dev.txt` for complete dependency list. +See `pyproject.toml` for complete dependency list ## References diff --git a/bandit.yml b/bandit.yml index 31c3ff7..55d0e42 100644 --- a/bandit.yml +++ b/bandit.yml @@ -12,7 +12,7 @@ exclude_dirs: - chroma_db - model_cache - remediation/corpus - - vulnerable # Vulnerable code used for testing +# - vulnerable # Vulnerable code used for testing - .pytest_cache - __pycache__ - .idea diff --git a/cve/scanner.py b/cve/scanner.py deleted file mode 100644 index 0896e40..0000000 --- a/cve/scanner.py +++ /dev/null @@ -1,31 +0,0 @@ -from cve.db import query_osv -from models import OSVFinding - - -def scan_requirements(requirements_file): - findings = [] - - # Try different encodings - encodings = ['utf-8-sig', 'utf-16', 'utf-16-le', 'utf-16-be', 'latin-1'] - - content = None - for encoding in encodings: - try: - with open(requirements_file, encoding=encoding) as f: - content = f.read() - break - except (UnicodeDecodeError, UnicodeError): - continue - - if content is None: - raise ValueError(f"Could not decode {requirements_file} with any known encoding") - - for line in content.splitlines(): - line = line.strip() - if '==' in line: - package, version = line.split('==') - cves = query_osv(package, version) - if cves: - findings.append(OSVFinding(package, version, cves, requirements_file)) - - return findings \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d3eecca --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,121 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "securefix" +version = 
"0.1.0" +description = "Static Application Security Testing with Smart Remediation" +readme = "README.md" +requires-python = ">=3.8" +authors = [ + {name = "HakAl" } +] +keywords = ["security", "sast", "vulnerability", "remediation", "llm"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Security", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] + +dependencies = [ + "requests>=2.32.0,<3.0.0", + "click>=8.3.0,<9.0.0", + "langchain>=0.1.0", + "langchain-community>=0.3.0,<0.4.0", + "langchain-chroma>=0.1.0", + "langchain-core>=0.3.0,<0.4.0", + "langchain-huggingface>=0.1.0", + "langchain-ollama>=0.1.0", + "langchain-google-genai>=2.0.0,<3.0.0", + "sentence-transformers>=2.2.0", + "rank-bm25>=0.2.1", + "chromadb>=0.4.0", + "tqdm>=4.65.0", + "numpy>=1.21.0", + "nltk>=3.8.0", + "google-generativeai", + "pyyaml>=6.0", + "markdown>=3.4.0", + "pydantic>=2.0,<3.0", + "python-dotenv>=1.0.0", + "bandit>=1.8.0,<2.0.0", + "json-repair>=0.52.0,<1.0.0", + "tomli>=2.0.0; python_version < '3.11'", + "packaging>=21.0", +] + +[project.optional-dependencies] +# Development dependencies +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", +] + +# LlamaCPP support for local model inference +llamacpp = [ + "llama-cpp-python>=0.2.0", +] + +# All optional dependencies +all = [ + "securefix[dev,llamacpp]", +] + +[project.scripts] +securefix = "securefix.cli:cli" + +[project.urls] +Homepage = "https://github.com/HakAl/securefix" +Repository = "https://github.com/HakAl/securefix" +Issues = "https://github.com/HakAl/securefix/issues" + +[tool.setuptools.packages.find] +where = ["."] +include = ["securefix*"] +exclude = ["tests*", "docs*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] 
+python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "-v", + "--strict-markers", + "--cov=securefix", + "--cov-report=term-missing", + "--cov-report=html", +] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "integration: marks tests as integration tests", + "unit: marks tests as unit tests", + "requires_nltk: marks tests that require NLTK data", + "requires_api: marks tests that require API keys", +] + +[tool.coverage.run] +source = ["securefix"] +omit = [ + "*/tests/*", + "*/test_*.py", + "*/__pycache__/*", + "*/venv/*", +] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 7571d09..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,2 +0,0 @@ -pytest>=7.0.0 -pytest-cov>=4.0.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index aa5a20fa6b57f10745458bb5c1db218e0437f21c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 914 zcmb7?-A=+#5QO*I#5b{oA0b|-uVSDrr9xY63x=0hzu7}uj3i>3b6VKh*`3+_`m8Or zciVF})^eVBrZ%#f-*daQ8;fmcIrrA14;~%a*kg+CY-y!sAU1sEd_RL6gPU*y?l&d^ zUs=JkCnKl&(fO4(`y+2-wdXGXQAtXr?kE4{{v(v^QIxrw%+~0N$-holg(K2iaIVY9 zP*YKDM~#SgiX(Een}&KG{BAtn`mL@aysN>N%x?`(P34rk@#>nUPRI|qEjcStQ=9i4 z3Nty)d+;{_*J-%#Sz2}`Azu;wyMU*wP~u!+Z%k8BZNT(w4@xw)Y?5M{vbHI7cF6;& zSc1>(fFe1fd5QbTuL3UdEP4((Uc;W?!pCV3Rb}X0xe2J!S!2%+?%dhq+1_<^S~j5q t9XV(4{Lm{K=CX3mgbrIe3c96j;T^v+3$eV=T}uB_QaerzbDz)5k>B18i0%LY diff --git a/cve/__init__.py b/securefix/__init__.py similarity index 100% rename from cve/__init__.py rename to securefix/__init__.py diff --git a/securefix.py b/securefix/cli.py similarity index 96% rename from securefix.py rename to securefix/cli.py index 0847741..00a3514 100644 
--- a/securefix.py +++ b/securefix/cli.py @@ -3,11 +3,11 @@ import json from datetime import datetime from pathlib import Path -import sast.bandit_scanner as bandit_scanner -import cve.scanner as cve_scanner -from models import ScanResult +import securefix.sast.bandit_scanner as bandit_scanner +import securefix.cve.scanner as cve_scanner +from securefix.models import ScanResult from json_repair import repair_json -from remediation.corpus_builder import DocumentProcessor +from securefix.remediation.corpus_builder import DocumentProcessor from typing import List, Dict @@ -20,7 +20,7 @@ def cli(): @cli.command() @click.argument('target', type=click.Path(exists=True)) @click.option('--dependencies', '-d', type=click.Path(exists=True), - help='Path to requirements.txt for CVE scanning') + help='Path to requirements.txt or pyproject.toml for CVE scanning') @click.option('--output', '-o', type=click.Path(), default='report.json', help='Output JSON file (default: report.json)') @click.option('--severity', '-s', type=click.Choice(['low', 'medium', 'high'], case_sensitive=True), @@ -172,10 +172,9 @@ def fix(report, output, interactive, llm_mode, model_name, no_cache, persist_dir severity_filter, sast_only, cve_only): """Generate security fixes for vulnerabilities in REPORT""" import time - from remediation.corpus_builder import DocumentProcessor - from remediation.fix_knowledge_store import DocumentStore - from remediation.llm_factory import LLMFactory - from remediation.remediation_engine import RemediationEngine + from securefix.remediation.corpus_builder import DocumentProcessor + from securefix.remediation.fix_knowledge_store import DocumentStore + from securefix.remediation.remediation_engine import RemediationEngine start_time = time.time() @@ -380,8 +379,8 @@ def fix(report, output, interactive, llm_mode, model_name, no_cache, persist_dir def _configure_llm(mode: str, model_name: str = None): """Configure LLM based on mode and validate availability.""" - from 
remediation.llm_factory import LLMFactory, check_ollama_available, check_google_api_key - from remediation.config import app_config + from securefix.remediation.llm_factory import LLMFactory, check_ollama_available, check_google_api_key + from securefix.remediation.config import app_config if mode == 'local': if not check_ollama_available(): diff --git a/corpus_downloader.py b/securefix/corpus_downloader.py similarity index 99% rename from corpus_downloader.py rename to securefix/corpus_downloader.py index db3949d..cc993a7 100644 --- a/corpus_downloader.py +++ b/securefix/corpus_downloader.py @@ -169,7 +169,7 @@ def download_corpus(corpus_path, skip_cwe, skip_owasp, skip_pypa): if success_count > 0: click.echo(f"\nāœ“ Corpus downloaded to {corpus_path}") click.echo("\nNext steps:") - click.echo(" 1. Run: python securefix.py ingest") + click.echo(" 1. Run: securefix ingest") click.echo(" 2. This will build the vector database (takes 5-10 minutes)") click.echo(" 3. Then you can scan and fix vulnerabilities!") else: diff --git a/remediation/__init__.py b/securefix/cve/__init__.py similarity index 100% rename from remediation/__init__.py rename to securefix/cve/__init__.py diff --git a/cve/db.py b/securefix/cve/db.py similarity index 94% rename from cve/db.py rename to securefix/cve/db.py index cb4fea0..3df0263 100644 --- a/cve/db.py +++ b/securefix/cve/db.py @@ -1,5 +1,5 @@ import requests -from models import OSVRequest +from securefix.models import OSVRequest def query_osv(package, version): diff --git a/securefix/cve/scanner.py b/securefix/cve/scanner.py new file mode 100644 index 0000000..1e6b446 --- /dev/null +++ b/securefix/cve/scanner.py @@ -0,0 +1,154 @@ +from securefix.cve.db import query_osv +from securefix.models import OSVFinding +import sys + +# Import tomllib/tomli at module level for easier testing +if sys.version_info >= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + tomllib = None # Will be checked in
scan_pyproject + + +def scan_requirements(requirements_file): + """Scan a requirements.txt file for CVE vulnerabilities.""" + findings = [] + + # Try different encodings + encodings = ['utf-8-sig', 'utf-16', 'utf-16-le', 'utf-16-be', 'latin-1'] + + content = None + for encoding in encodings: + try: + with open(requirements_file, encoding=encoding) as f: + content = f.read() + break + except (UnicodeDecodeError, UnicodeError): + continue + + if content is None: + raise ValueError(f"Could not decode {requirements_file} with any known encoding") + + for line in content.splitlines(): + line = line.strip() + if '==' in line: + package, version = line.split('==') + cves = query_osv(package, version) + if cves: + findings.append(OSVFinding(package, version, cves, requirements_file)) + + return findings + + +def scan_pyproject(pyproject_file): + """Scan a pyproject.toml file for CVE vulnerabilities.""" + findings = [] + + # Check if tomllib is available + if tomllib is None: + raise ImportError( + "tomli is required for Python < 3.11. Install with: pip install tomli" + ) + + try: + with open(pyproject_file, 'rb') as f: + pyproject_data = tomllib.load(f) + except FileNotFoundError: + raise FileNotFoundError(f"pyproject.toml not found at {pyproject_file}") + except Exception as e: + raise ValueError(f"Error parsing pyproject.toml: {e}") + + # Extract dependencies from [project.dependencies] + dependencies = pyproject_data.get('project', {}).get('dependencies', []) + + for dep_spec in dependencies: + # Parse dependency specification (e.g., "requests>=2.32.0,<3.0.0") + package_name, version = _parse_dependency_spec(dep_spec) + + if package_name and version: + cves = query_osv(package_name, version) + if cves: + findings.append(OSVFinding(package_name, version, cves, pyproject_file)) + + return findings + + +def _parse_dependency_spec(dep_spec): + """ + Parse a PEP 508 dependency specification to extract package name and version. 
+ + Examples: + "requests==2.32.0" -> ("requests", "2.32.0") + "requests>=2.32.0,<3.0.0" -> ("requests", "2.32.0") + "click>=8.3.0,<9.0.0" -> ("click", "8.3.0") + + Returns: + Tuple of (package_name, version) or (None, None) if parsing fails + """ + try: + from packaging.requirements import Requirement + from packaging.specifiers import SpecifierSet + + req = Requirement(dep_spec) + package_name = req.name + + # Extract version from specifiers + if req.specifier: + # Get the minimum version from the specifier set + # For specs like ">=2.32.0,<3.0.0", we want to check the minimum version + version = _extract_version_from_specifier(req.specifier) + return package_name, version + else: + # No version specifier, can't check CVEs + return None, None + + except Exception as e: + print(f"Warning: Could not parse dependency spec '{dep_spec}': {e}") + return None, None + + +def _extract_version_from_specifier(specifier_set): + """ + Extract a version string from a SpecifierSet. + + Prioritizes: + 1. Exact version (==) + 2. Minimum version (>=, >) + 3. Maximum version (<, <=) + """ + for spec in specifier_set: + operator = spec.operator + version = spec.version + + # Exact match - use this version + if operator == "==": + return version + + # Greater than or equal - use minimum version + if operator in (">=", ">"): + return version + + # If no >= or ==, look for < or <= + for spec in specifier_set: + if spec.operator in ("<", "<="): + return spec.version + + return None + + +def scan_dependencies(dependency_file): + """ + Auto-detect and scan either requirements.txt or pyproject.toml. 
+ + Args: + dependency_file: Path to requirements.txt or pyproject.toml + + Returns: + List of OSVFinding objects + """ + if dependency_file.endswith('.toml'): + return scan_pyproject(dependency_file) + else: + return scan_requirements(dependency_file) \ No newline at end of file diff --git a/models.py b/securefix/models.py similarity index 100% rename from models.py rename to securefix/models.py diff --git a/sast/__init__.py b/securefix/remediation/__init__.py similarity index 100% rename from sast/__init__.py rename to securefix/remediation/__init__.py diff --git a/remediation/config.py b/securefix/remediation/config.py similarity index 100% rename from remediation/config.py rename to securefix/remediation/config.py diff --git a/remediation/corpus_builder.py b/securefix/remediation/corpus_builder.py similarity index 99% rename from remediation/corpus_builder.py rename to securefix/remediation/corpus_builder.py index fb68f42..22ab6d1 100644 --- a/remediation/corpus_builder.py +++ b/securefix/remediation/corpus_builder.py @@ -1,5 +1,4 @@ import os -import json import logging import time import yaml @@ -8,8 +7,8 @@ from contextlib import contextmanager from typing import List, Dict, NamedTuple, Optional, Tuple -from remediation.config import app_config -from remediation.markdown_processor import process_markdown_file +from securefix.remediation.config import app_config +from securefix.remediation.markdown_processor import process_markdown_file from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_chroma import Chroma from langchain_core.documents import Document diff --git a/remediation/fix_cache.py b/securefix/remediation/fix_cache.py similarity index 100% rename from remediation/fix_cache.py rename to securefix/remediation/fix_cache.py diff --git a/remediation/fix_knowledge_store.py b/securefix/remediation/fix_knowledge_store.py similarity index 94% rename from remediation/fix_knowledge_store.py rename to 
securefix/remediation/fix_knowledge_store.py index 0a2e508..7e9deae 100644 --- a/remediation/fix_knowledge_store.py +++ b/securefix/remediation/fix_knowledge_store.py @@ -1,5 +1,5 @@ from typing import Any, List, Optional -from remediation.vulnerability_retriever import create_hybrid_retrieval_pipeline +from securefix.remediation.vulnerability_retriever import create_hybrid_retrieval_pipeline from langchain_chroma import Chroma from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever diff --git a/remediation/llm_factory.py b/securefix/remediation/llm_factory.py similarity index 100% rename from remediation/llm_factory.py rename to securefix/remediation/llm_factory.py diff --git a/remediation/markdown_processor.py b/securefix/remediation/markdown_processor.py similarity index 100% rename from remediation/markdown_processor.py rename to securefix/remediation/markdown_processor.py diff --git a/remediation/remediation_engine.py b/securefix/remediation/remediation_engine.py similarity index 97% rename from remediation/remediation_engine.py rename to securefix/remediation/remediation_engine.py index ac6ed30..2aab72c 100644 --- a/remediation/remediation_engine.py +++ b/securefix/remediation/remediation_engine.py @@ -1,14 +1,13 @@ -import re import time -from remediation.fix_knowledge_store import DocumentStore +from securefix.remediation.fix_knowledge_store import DocumentStore from langchain.callbacks.base import BaseCallbackHandler from langchain.chains import RetrievalQA from langchain_chroma import Chroma from langchain_core.documents import Document from langchain_core.runnables import RunnableSerializable -from remediation.llm_factory import LLMConfig -from remediation.fix_cache import SemanticQueryCache -from typing import Any, Dict, List, Optional, Callable +from securefix.remediation.llm_factory import LLMConfig +from securefix.remediation.fix_cache import SemanticQueryCache +from typing import Any, Dict, Optional, 
Callable # --- NLTK Integration for Preprocessing --- try: diff --git a/remediation/vulnerability_retriever.py b/securefix/remediation/vulnerability_retriever.py similarity index 99% rename from remediation/vulnerability_retriever.py rename to securefix/remediation/vulnerability_retriever.py index b58fba5..13e1341 100644 --- a/remediation/vulnerability_retriever.py +++ b/securefix/remediation/vulnerability_retriever.py @@ -1,5 +1,5 @@ import numpy as np -from remediation.config import app_config +from securefix.remediation.config import app_config from langchain_chroma import Chroma from langchain_core.retrievers import BaseRetriever from langchain_core.documents import Document diff --git a/securefix/sast/__init__.py b/securefix/sast/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sast/bandit_mapper.py b/securefix/sast/bandit_mapper.py similarity index 98% rename from sast/bandit_mapper.py rename to securefix/sast/bandit_mapper.py index a3d4706..06a37eb 100644 --- a/sast/bandit_mapper.py +++ b/securefix/sast/bandit_mapper.py @@ -1,5 +1,5 @@ from typing import Dict, Any -from models import Type, Severity, Confidence, Finding +from securefix.models import Type, Severity, Confidence, Finding # This dictionary maps the Bandit 'test_id' BANDIT_ID_TO_TYPE = { diff --git a/sast/bandit_scanner.py b/securefix/sast/bandit_scanner.py similarity index 95% rename from sast/bandit_scanner.py rename to securefix/sast/bandit_scanner.py index bf51ddc..45e35bf 100644 --- a/sast/bandit_scanner.py +++ b/securefix/sast/bandit_scanner.py @@ -2,8 +2,8 @@ import json import os from typing import List, Optional -from models import Finding -from sast.bandit_mapper import convert_bandit_result +from securefix.models import Finding +from securefix.sast.bandit_mapper import convert_bandit_result def _find_bandit_config() -> Optional[str]: diff --git a/tests/conftest.py b/tests/conftest.py index 9af1eb3..ef1e6c7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ 
-273,7 +273,7 @@ def mock_app_config(): @pytest.fixture(autouse=True) def mock_bandit_config_finder(): """Automatically mock config finder for all tests""" - with patch('sast.bandit_scanner._find_bandit_config', return_value=None): + with patch('securefix.sast.bandit_scanner._find_bandit_config', return_value=None): yield diff --git a/tests/test_bandit_scanner.py b/tests/test_bandit_scanner.py index acbd0f3..7ac3897 100644 --- a/tests/test_bandit_scanner.py +++ b/tests/test_bandit_scanner.py @@ -1,8 +1,8 @@ import pytest import json from unittest.mock import patch, MagicMock -from sast.bandit_scanner import scan -from models import Finding, Type, Severity, Confidence +from securefix.sast.bandit_scanner import scan +from securefix.models import Finding, Type, Severity, Confidence @pytest.fixture @@ -46,7 +46,7 @@ class TestBanditScanner: def test_scan_success_with_findings(self, mock_subprocess_success): """Test successful scan with vulnerabilities found""" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): findings = scan("/test/path", "medium", "medium") assert len(findings) == 2 @@ -71,14 +71,14 @@ def test_scan_no_findings(self): mock_result.stdout = json.dumps({"results": []}) mock_result.stderr = "" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_result): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_result): findings = scan("/test/path", "medium", "medium") assert len(findings) == 0 def test_scan_bandit_not_installed(self): """Test when Bandit is not installed""" - with patch('sast.bandit_scanner.subprocess.run', side_effect=FileNotFoundError()): + with patch('securefix.sast.bandit_scanner.subprocess.run', side_effect=FileNotFoundError()): findings = scan("/test/path", "medium", "medium") assert findings == [] @@ -90,7 +90,7 @@ def test_scan_bandit_failure(self): 
mock_result.stdout = "" mock_result.stderr = "Bandit error" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_result): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_result): findings = scan("/test/path", "medium", "medium") assert findings == [] @@ -102,14 +102,14 @@ def test_scan_invalid_json(self): mock_result.stdout = "not valid json" mock_result.stderr = "" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_result): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_result): findings = scan("/test", "medium", "medium") # ← Add these two parameters assert findings == [] def test_scan_command_construction(self, mock_subprocess_success): """Test that the correct command is constructed""" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success) as mock_run: + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success) as mock_run: scan("/my/test/path", "medium", "medium") # Check the command that was called - include the additional parameters @@ -131,14 +131,14 @@ def test_scan_command_construction(self, mock_subprocess_success): def test_scan_handles_file_path(self, mock_subprocess_success): """Test scanning a single file""" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): findings = scan("/path/to/file.py", "medium", "medium") assert isinstance(findings, list) def test_scan_handles_directory_path(self, mock_subprocess_success): """Test scanning a directory""" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_subprocess_success): findings = scan("/path/to/directory", "medium", "medium") assert isinstance(findings, list) @@ -160,7 +160,7 @@ def 
test_scan_snippet_cleaning(self): }) mock_result.stderr = "" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_result): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_result): findings = scan("/test", "medium", "medium") # Snippet should not contain line number @@ -185,7 +185,7 @@ def test_scan_multiline_snippet(self): }) mock_result.stderr = "" - with patch('sast.bandit_scanner.subprocess.run', return_value=mock_result): + with patch('securefix.sast.bandit_scanner.subprocess.run', return_value=mock_result): findings = scan("/test", "medium", "medium") # Check that multiline snippet is preserved diff --git a/tests/test_corpus_builder.py b/tests/test_corpus_builder.py index c05e084..915e142 100644 --- a/tests/test_corpus_builder.py +++ b/tests/test_corpus_builder.py @@ -10,16 +10,9 @@ """ import os import pytest -from unittest.mock import Mock, patch, MagicMock, call -from pathlib import Path +from unittest.mock import Mock, patch from langchain_core.documents import Document - -from remediation.corpus_builder import ( - DocumentProcessor, - LoadResult, - ProgressEmbeddings, - create_progress_embeddings -) +from securefix.remediation.corpus_builder import DocumentProcessor, LoadResult, ProgressEmbeddings, create_progress_embeddings @pytest.fixture @@ -290,7 +283,7 @@ def test_process_documents_with_failed_files(self, mock_create_vs, mock_load, class TestVectorStore: """Test vector store operations""" - @patch('remediation.corpus_builder.Chroma') + @patch('securefix.remediation.corpus_builder.Chroma') def test_create_vectorstore(self, mock_chroma, doc_processor, sample_security_documents): """Test vector store creation""" mock_vs = Mock() @@ -301,7 +294,7 @@ def test_create_vectorstore(self, mock_chroma, doc_processor, sample_security_do assert vs == mock_vs mock_chroma.from_documents.assert_called_once() - @patch('remediation.corpus_builder.Chroma') + @patch('securefix.remediation.corpus_builder.Chroma') def 
test_load_existing_vectorstore_success(self, mock_chroma, doc_processor): """Test loading existing vector store""" mock_vs = Mock() @@ -328,7 +321,7 @@ def test_load_existing_vectorstore_not_found(self, doc_processor): assert bm25_index is None assert bm25_chunks is None - @patch('remediation.corpus_builder.Chroma') + @patch('securefix.remediation.corpus_builder.Chroma') def test_load_existing_vectorstore_empty(self, mock_chroma, doc_processor): """Test loading empty vector store""" mock_vs = Mock() diff --git a/tests/test_cve_scanner.py b/tests/test_cve_scanner.py index c6b54f6..1b5c557 100644 --- a/tests/test_cve_scanner.py +++ b/tests/test_cve_scanner.py @@ -1,174 +1,313 @@ from unittest.mock import mock_open, patch, MagicMock -import requests - -from cve.db import query_osv -from cve.scanner import scan_requirements -from models import OSVFinding - - -class TestCheckOSVAPI: - """Tests for the query_osv function""" - - @patch('cve.db.requests.post') - def test_successful_vulnerability_found(self, mock_post): - """Test when vulnerabilities are found""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - 'vulns': [ - {'id': 'CVE-2021-44228'}, - {'id': 'CVE-2021-45046'} - ] - } - mock_post.return_value = mock_response +import pytest +import sys + +from securefix.cve.scanner import ( + scan_pyproject, + scan_dependencies, + _parse_dependency_spec, + _extract_version_from_specifier +) +from securefix.models import OSVFinding + + +class TestParseDependencySpec: + """Tests for parsing PEP 508 dependency specifications""" + + def test_parse_exact_version(self): + """Test parsing exact version specification""" + package, version = _parse_dependency_spec("requests==2.32.0") + assert package == "requests" + assert version == "2.32.0" + + def test_parse_minimum_version(self): + """Test parsing >= version specification""" + package, version = _parse_dependency_spec("requests>=2.32.0") + assert package == "requests" + assert 
version == "2.32.0" + + def test_parse_version_range(self): + """Test parsing version range with multiple specifiers""" + package, version = _parse_dependency_spec("requests>=2.32.0,<3.0.0") + assert package == "requests" + assert version == "2.32.0" # Should extract minimum version + + def test_parse_complex_spec(self): + """Test parsing complex specification""" + package, version = _parse_dependency_spec("click>=8.3.0,<9.0.0") + assert package == "click" + assert version == "8.3.0" + + def test_parse_no_version_specifier(self): + """Test parsing dependency without version""" + package, version = _parse_dependency_spec("requests") + assert package is None + assert version is None + + def test_parse_with_extras(self): + """Test parsing dependency with extras""" + package, version = _parse_dependency_spec("requests[security]>=2.32.0") + assert package == "requests" + assert version == "2.32.0" + + def test_parse_invalid_spec(self): + """Test handling of invalid dependency spec""" + package, version = _parse_dependency_spec("invalid>>2.0.0") + assert package is None + assert version is None + + def test_parse_with_environment_marker(self): + """Test parsing dependency with environment marker""" + package, version = _parse_dependency_spec("tomli>=2.0.0; python_version < '3.11'") + assert package == "tomli" + assert version == "2.0.0" + + +class TestExtractVersionFromSpecifier: + """Tests for version extraction from specifier sets""" + + def test_extract_exact_version(self): + """Test extracting exact version (==)""" + from packaging.specifiers import SpecifierSet + spec_set = SpecifierSet("==2.32.0") + version = _extract_version_from_specifier(spec_set) + assert version == "2.32.0" + + def test_extract_minimum_version(self): + """Test extracting from >= specifier""" + from packaging.specifiers import SpecifierSet + spec_set = SpecifierSet(">=2.32.0,<3.0.0") + version = _extract_version_from_specifier(spec_set) + assert version == "2.32.0" + + def 
test_extract_from_less_than_only(self): + """Test extracting from < specifier when no >= exists""" + from packaging.specifiers import SpecifierSet + spec_set = SpecifierSet("<3.0.0") + version = _extract_version_from_specifier(spec_set) + assert version == "3.0.0" + + def test_extract_prioritizes_exact(self): + """Test that == is prioritized over other operators""" + from packaging.specifiers import SpecifierSet + spec_set = SpecifierSet(">=2.0.0,==2.32.0,<3.0.0") + version = _extract_version_from_specifier(spec_set) + assert version == "2.32.0" + + +class TestScanPyproject: + """Tests for scanning pyproject.toml files""" + + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_with_vulnerabilities(self, mock_query_osv): + """Test scanning pyproject.toml with vulnerable dependencies""" + pyproject_content = b""" +[project] +dependencies = [ + "flask==0.12", + "requests>=2.32.0,<3.0.0", + "django==2.2.0" +] +""" + + def query_side_effect(package, version): + if package == 'flask': + return ['CVE-2018-1000656'] + elif package == 'django': + return ['CVE-2019-14234'] + return [] - result = query_osv('log4j', '2.14.1') + mock_query_osv.side_effect = query_side_effect + + # Mock tomllib/tomli load + if sys.version_info >= (3, 11): + import tomllib + mock_module = 'tomllib' + else: + mock_module = 'tomli' + + with patch('builtins.open', mock_open(read_data=pyproject_content)): + with patch(f'securefix.cve.scanner.{mock_module}.load') as mock_load: + mock_load.return_value = { + 'project': { + 'dependencies': [ + 'flask==0.12', + 'requests>=2.32.0,<3.0.0', + 'django==2.2.0' + ] + } + } + findings = scan_pyproject('pyproject.toml') - assert result == ['CVE-2021-44228', 'CVE-2021-45046'] - mock_post.assert_called_once() + assert len(findings) == 2 + assert findings[0].package == 'flask' + assert findings[0].version == '0.12' + assert findings[0].cves == ['CVE-2018-1000656'] + assert findings[1].package == 'django' + assert findings[1].version == '2.2.0' - 
@patch('cve.db.requests.post') - def test_no_vulnerabilities_found(self, mock_post): - """Test when no vulnerabilities are found""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {'vulns': []} - mock_post.return_value = mock_response + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_no_vulnerabilities(self, mock_query_osv): + """Test scanning pyproject.toml with no vulnerabilities""" + mock_query_osv.return_value = [] + + pyproject_data = { + 'project': { + 'dependencies': [ + 'requests>=2.32.0', + 'click>=8.3.0' + ] + } + } - result = query_osv('safe-package', '1.0.0') + with patch('builtins.open', mock_open()): + with patch('securefix.cve.scanner.tomllib.load', return_value=pyproject_data): + findings = scan_pyproject('pyproject.toml') - assert result == [] + assert len(findings) == 0 - @patch('cve.db.requests.post') - def test_api_error_returns_empty_list(self, mock_post): - """Test that API errors are handled gracefully""" - mock_post.side_effect = requests.RequestException("API Error") + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_no_dependencies(self, mock_query_osv): + """Test scanning pyproject.toml with no dependencies section""" + pyproject_data = {'project': {}} - result = query_osv('flask', '0.12') + with patch('builtins.open', mock_open()): + with patch('securefix.cve.scanner.tomllib.load', return_value=pyproject_data): + findings = scan_pyproject('pyproject.toml') - assert result == [] + assert len(findings) == 0 + mock_query_osv.assert_not_called() + + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_skips_unparseable_deps(self, mock_query_osv): + """Test that unparseable dependencies are skipped gracefully""" + mock_query_osv.return_value = [] + + if sys.version_info >= (3, 11): + mock_module = 'tomllib' + else: + mock_module = 'tomli' + + with patch('builtins.open', mock_open()): + with 
patch(f'securefix.cve.scanner.{mock_module}.load') as mock_load: + mock_load.return_value = { + 'project': { + 'dependencies': [ + 'requests>=2.32.0', # Valid + 'invalid>>spec', # Invalid + 'click' # No version + ] + } + } + findings = scan_pyproject('pyproject.toml') + + # Should only check 'requests' (the valid one with version) + assert mock_query_osv.call_count == 1 + mock_query_osv.assert_called_with('requests', '2.32.0') + + def test_scan_pyproject_file_not_found(self): + """Test handling of missing pyproject.toml""" + with pytest.raises(FileNotFoundError): + scan_pyproject('nonexistent.toml') + + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_invalid_toml(self, mock_query_osv): + """Test handling of invalid TOML syntax""" + if sys.version_info >= (3, 11): + mock_module = 'tomllib' + else: + mock_module = 'tomli' + + with patch('builtins.open', mock_open()): + with patch(f'securefix.cve.scanner.{mock_module}.load') as mock_load: + mock_load.side_effect = Exception("Invalid TOML") + + with pytest.raises(ValueError, match="Error parsing pyproject.toml"): + scan_pyproject('pyproject.toml') + + @patch('securefix.cve.scanner.query_osv') + def test_scan_pyproject_with_environment_markers(self, mock_query_osv): + """Test scanning dependencies with environment markers""" + mock_query_osv.return_value = ['CVE-2023-1234'] + + pyproject_data = { + 'project': { + 'dependencies': [ + "tomli>=2.0.0; python_version < '3.11'", + ] + } + } - @patch('cve.db.requests.post') - def test_non_200_status_code(self, mock_post): - """Test handling of non-200 status codes""" - mock_response = MagicMock() - mock_response.status_code = 404 - mock_post.return_value = mock_response + with patch('builtins.open', mock_open()): + with patch('securefix.cve.scanner.tomllib.load', return_value=pyproject_data): + findings = scan_pyproject('pyproject.toml') - result = query_osv('unknown', '1.0.0') + # Should parse and check despite environment marker + assert 
mock_query_osv.call_count == 1 + mock_query_osv.assert_called_with('tomli', '2.0.0') - assert result == [] - @patch('cve.db.requests.post') - def test_request_payload_format(self, mock_post): - """Test that the API request payload is formatted correctly""" - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {'vulns': []} - mock_post.return_value = mock_response +class TestScanDependenciesAutoDetect: + """Tests for auto-detecting file type""" - query_osv('requests', '2.28.0') + @patch('securefix.cve.scanner.scan_requirements') + def test_autodetect_requirements_txt(self, mock_scan_requirements): + """Test that requirements.txt is detected and scanned""" + mock_scan_requirements.return_value = [] - expected_payload = { - 'package': { - 'name': 'requests', - 'ecosystem': 'PyPI' - }, - 'version': '2.28.0' - } - mock_post.assert_called_once_with( - "https://api.osv.dev/v1/query", - json=expected_payload, - timeout=5 - ) + scan_dependencies('requirements.txt') + mock_scan_requirements.assert_called_once_with('requirements.txt') -class TestScanDependencies: - """Tests for the scan_requirements function""" + @patch('securefix.cve.scanner.scan_pyproject') + def test_autodetect_pyproject_toml(self, mock_scan_pyproject): + """Test that pyproject.toml is detected and scanned""" + mock_scan_pyproject.return_value = [] - @patch('cve.scanner.query_osv') - def test_scan_with_vulnerabilities(self, mock_check_osv): - """Test scanning a requirements file with vulnerabilities""" - requirements_content = "flask==0.12\nrequests==2.28.0\ndjango==2.2.0\n" + scan_dependencies('pyproject.toml') - # Mock different responses for different packages - def check_side_effect(package, version): - if package == 'flask': - return ['CVE-2018-1000656'] - elif package == 'django': - return ['CVE-2019-14234'] - return [] + mock_scan_pyproject.assert_called_once_with('pyproject.toml') - mock_check_osv.side_effect = check_side_effect + 
@patch('securefix.cve.scanner.scan_requirements') + def test_autodetect_defaults_to_requirements(self, mock_scan_requirements): + """Test that non-.toml files default to requirements scanner""" + mock_scan_requirements.return_value = [] - with patch('builtins.open', mock_open(read_data=requirements_content)): - findings = scan_requirements('requirements.txt') + scan_dependencies('deps.txt') - assert len(findings) == 2 - assert findings[0].package == 'flask' - assert findings[0].version == '0.12' - assert findings[0].cves == ['CVE-2018-1000656'] - assert findings[1].package == 'django' - assert findings[1].version == '2.2.0' - assert findings[1].cves == ['CVE-2019-14234'] + mock_scan_requirements.assert_called_once_with('deps.txt') - @patch('cve.scanner.query_osv') - def test_scan_with_no_vulnerabilities(self, mock_check_osv): - """Test scanning when no vulnerabilities are found""" - requirements_content = "requests==2.28.0\nnumpy==1.24.0\n" - mock_check_osv.return_value = [] + @patch('securefix.cve.scanner.scan_pyproject') + def test_autodetect_case_sensitive_toml(self, mock_scan_pyproject): + """Test that .toml extension detection is case-sensitive""" + mock_scan_pyproject.return_value = [] - with patch('builtins.open', mock_open(read_data=requirements_content)): - findings = scan_requirements('requirements.txt') + scan_dependencies('pyproject.toml') - assert len(findings) == 0 + mock_scan_pyproject.assert_called_once() - @patch('cve.scanner.query_osv') - def test_scan_skips_lines_without_version_pin(self, mock_check_osv): - """Test that lines without == are skipped""" - requirements_content = "flask==0.12\nrequests\n# comment\ndjango>=2.2.0\n" - mock_check_osv.return_value = [] - with patch('builtins.open', mock_open(read_data=requirements_content)): - findings = scan_requirements('requirements.txt') +class TestTomliImport: + """Tests for tomli/tomllib import handling""" - # Should only check flask (the only one with ==) - assert mock_check_osv.call_count == 1 - 
mock_check_osv.assert_called_with('flask', '0.12') + def test_tomli_not_installed_python_310(self): + """Test that helpful error is raised if tomli not installed on Python < 3.11""" + if sys.version_info >= (3, 11): + pytest.skip("Test only relevant for Python < 3.11") - @patch('cve.scanner.query_osv') - def test_scan_empty_file(self, mock_check_osv): - """Test scanning an empty requirements file""" - with patch('builtins.open', mock_open(read_data="")): - findings = scan_requirements('requirements.txt') + with patch.dict('sys.modules', {'tomli': None}): + with patch('builtins.open', mock_open()): + with pytest.raises(ImportError, match="tomli is required"): + scan_pyproject('pyproject.toml') - assert len(findings) == 0 - mock_check_osv.assert_not_called() - - @patch('cve.scanner.query_osv') - def test_osv_finding_structure(self, mock_check_osv): - """Test that OSVFinding objects are created correctly""" - requirements_content = "flask==0.12\n" - mock_check_osv.return_value = ['CVE-2018-1000656', 'CVE-2019-1010083'] - - with patch('builtins.open', mock_open(read_data=requirements_content)): - findings = scan_requirements('requirements.txt') - - assert len(findings) == 1 - finding = findings[0] - assert isinstance(finding, OSVFinding) - assert finding.package == 'flask' - assert finding.version == '0.12' - assert finding.cves == ['CVE-2018-1000656', 'CVE-2019-1010083'] - - @patch('cve.scanner.query_osv') - def test_scan_with_whitespace(self, mock_check_osv): - """Test that whitespace in requirements is handled""" - requirements_content = " flask==0.12 \n\ndjango==2.2.0\n" - mock_check_osv.return_value = [] - - with patch('builtins.open', mock_open(read_data=requirements_content)): - findings = scan_requirements('requirements.txt') - - # Verify both packages were checked despite whitespace - assert mock_check_osv.call_count == 2 \ No newline at end of file + def test_uses_tomllib_python_311_plus(self): + """Test that tomllib is used on Python 3.11+""" + if 
sys.version_info < (3, 11): + pytest.skip("Test only relevant for Python >= 3.11") + + # Should not raise ImportError about tomli + with patch('builtins.open', mock_open()): + with patch('tomllib.load') as mock_load: + mock_load.return_value = {'project': {}} + scan_pyproject('pyproject.toml') + mock_load.assert_called_once() \ No newline at end of file diff --git a/tests/test_fix_cache.py b/tests/test_fix_cache.py index e614980..a17d1a9 100644 --- a/tests/test_fix_cache.py +++ b/tests/test_fix_cache.py @@ -8,11 +8,10 @@ - Embedding-based similarity """ import pytest -from unittest.mock import Mock, patch -import numpy as np +from unittest.mock import Mock try: - from remediation.fix_cache import SemanticQueryCache + from securefix.remediation.fix_cache import SemanticQueryCache except ImportError: pytest.skip("fix_cache module not found", allow_module_level=True) diff --git a/tests/test_fix_knowledge_store.py b/tests/test_fix_knowledge_store.py index 33e11ee..a7bfae7 100644 --- a/tests/test_fix_knowledge_store.py +++ b/tests/test_fix_knowledge_store.py @@ -7,11 +7,11 @@ - Integration with vector stores and BM25 """ import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch from langchain_core.documents import Document try: - from remediation.fix_knowledge_store import DocumentStore + from securefix.remediation.fix_knowledge_store import DocumentStore except ImportError: pytest.skip("fix_knowledge_store module not found", allow_module_level=True) @@ -86,7 +86,7 @@ def test_init_with_bm25(self, mock_vector_store, mock_bm25_index, sample_securit class TestRetrieverCreation: """Test retriever creation and configuration""" - @patch('remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') + @patch('securefix.remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') def test_get_retriever_uses_factory(self, mock_factory, doc_store): """Test that get_retriever uses the factory function""" mock_retriever = 
Mock() @@ -97,7 +97,7 @@ def test_get_retriever_uses_factory(self, mock_factory, doc_store): mock_factory.assert_called_once() assert retriever == mock_retriever - @patch('remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') + @patch('securefix.remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') def test_get_retriever_caches_instance(self, mock_factory, doc_store): """Test that retriever is cached after first build""" mock_retriever = Mock() @@ -110,7 +110,7 @@ def test_get_retriever_caches_instance(self, mock_factory, doc_store): assert mock_factory.call_count == 1 assert retriever1 is retriever2 - @patch('remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') + @patch('securefix.remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') def test_get_retriever_passes_bm25_components(self, mock_factory, doc_store_with_bm25): """Test that BM25 components are passed to factory""" doc_store_with_bm25.get_retriever() @@ -119,7 +119,7 @@ def test_get_retriever_passes_bm25_components(self, mock_factory, doc_store_with assert call_kwargs['bm25_index'] == doc_store_with_bm25.bm25_index assert call_kwargs['bm25_chunks'] == doc_store_with_bm25.bm25_chunks - @patch('remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') + @patch('securefix.remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') def test_get_retriever_enables_reranking(self, mock_factory, doc_store): """Test that reranking is enabled by default""" doc_store.get_retriever() @@ -318,7 +318,7 @@ def test_retriever_cached_state(self, doc_store): """Test that retriever is cached in _retriever attribute""" assert doc_store._retriever is None - with patch('remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') as mock_factory: + with patch('securefix.remediation.fix_knowledge_store.create_hybrid_retrieval_pipeline') as mock_factory: mock_retriever = Mock() mock_factory.return_value = mock_retriever diff --git 
a/tests/test_llm_factory.py b/tests/test_llm_factory.py index 00cf8eb..8930e5c 100644 --- a/tests/test_llm_factory.py +++ b/tests/test_llm_factory.py @@ -1,17 +1,12 @@ import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch import sys import os +from securefix.remediation.llm_factory import LLMFactory, GoogleGenAIConfig, OllamaConfig, check_ollama_available # Add parent directory to path so we can import modules from remediation/ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -from remediation.llm_factory import ( - LLMFactory, - GoogleGenAIConfig, - OllamaConfig, - check_ollama_available, -) class TestLLMFactory: diff --git a/tests/test_remediation_engine.py b/tests/test_remediation_engine.py index a1b917b..5a2924b 100644 --- a/tests/test_remediation_engine.py +++ b/tests/test_remediation_engine.py @@ -8,14 +8,14 @@ - Integration with LLM and retriever """ import pytest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import patch from langchain_core.documents import Document - -from remediation.remediation_engine import RemediationEngine, RemediationEngineError +from securefix.remediation.remediation_engine import RemediationEngine, RemediationEngineError # Mock NLTK at module import time if it doesn't exist -import remediation.remediation_engine as remediation_engine +from securefix.remediation import remediation_engine + if not hasattr(remediation_engine, 'word_tokenize'): remediation_engine.word_tokenize = lambda x: x.split() @@ -376,7 +376,7 @@ class TestTokenStreamCallbackHandler: def test_callback_handler_calls_function(self): """Test that callback handler invokes the provided function""" - from remediation.remediation_engine import TokenStreamCallbackHandler + from securefix.remediation.remediation_engine import TokenStreamCallbackHandler tokens_received = [] @@ -393,7 +393,7 @@ def callback(token): def test_callback_handler_with_kwargs(self): """Test 
that callback handler handles additional kwargs""" - from remediation.remediation_engine import TokenStreamCallbackHandler + from securefix.remediation.remediation_engine import TokenStreamCallbackHandler tokens_received = [] diff --git a/tests/test_securefix.py b/tests/test_securefix.py index 3ecf7a2..f65f10a 100644 --- a/tests/test_securefix.py +++ b/tests/test_securefix.py @@ -4,7 +4,7 @@ from click.testing import CliRunner from pathlib import Path from unittest.mock import patch -from securefix import cli +from securefix.cli import cli @pytest.fixture @@ -57,7 +57,7 @@ def test_scan_file_with_vulnerabilities(self, runner, temp_vulnerable_file, tmp_ output_path = tmp_path / "report.json" # Mock config finder to avoid picking up any local bandit config - with patch('sast.bandit_scanner._find_bandit_config', return_value=None): + with patch('securefix.sast.bandit_scanner._find_bandit_config', return_value=None): result = runner.invoke(cli, [ 'scan', str(temp_vulnerable_file), @@ -140,7 +140,7 @@ def test_scan_with_dependencies(self, runner, temp_vulnerable_file, temp_require output_path = tmp_path / "report.json" # Mock config finder to avoid picking up any local bandit config - with patch('sast.bandit_scanner._find_bandit_config', return_value=None): + with patch('securefix.sast.bandit_scanner._find_bandit_config', return_value=None): result = runner.invoke(cli, [ 'scan', str(temp_vulnerable_file), diff --git a/tests/test_vulnerability_retriever.py b/tests/test_vulnerability_retriever.py index 450c045..efdcafa 100644 --- a/tests/test_vulnerability_retriever.py +++ b/tests/test_vulnerability_retriever.py @@ -10,16 +10,9 @@ """ import pytest import numpy as np -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch from langchain_core.documents import Document - -from remediation.vulnerability_retriever import ( - HybridRetriever, - RerankerCompressor, - get_reranker, - create_hybrid_retrieval_pipeline, - _RERANKER_CACHE -) +from 
securefix.remediation.vulnerability_retriever import HybridRetriever, RerankerCompressor, get_reranker, create_hybrid_retrieval_pipeline, _RERANKER_CACHE @pytest.fixture @@ -280,7 +273,7 @@ def teardown_method(self): def test_get_reranker_creates_instance(self): """Test that get_reranker creates a new instance""" - with patch('remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: + with patch('securefix.remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: mock_instance = Mock() mock_compressor.return_value = mock_instance @@ -291,7 +284,7 @@ def test_get_reranker_creates_instance(self): def test_get_reranker_caches_instance(self): """Test that get_reranker caches instances""" - with patch('remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: + with patch('securefix.remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: mock_instance = Mock() mock_compressor.return_value = mock_instance @@ -305,7 +298,7 @@ def test_get_reranker_caches_instance(self): def test_get_reranker_different_params_different_cache(self): """Test that different parameters create different cache entries""" - with patch('remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: + with patch('securefix.remediation.vulnerability_retriever.RerankerCompressor') as mock_compressor: mock_compressor.side_effect = [Mock(), Mock()] reranker1 = get_reranker(model_name="model1", top_k=3) From 18622dd390d5b37bae298ba72459335e0a4318b4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 30 Oct 2025 17:01:12 -0400 Subject: [PATCH 2/3] fix cli integration --- securefix/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/securefix/cli.py b/securefix/cli.py index 00a3514..f139d4b 100644 --- a/securefix/cli.py +++ b/securefix/cli.py @@ -46,7 +46,7 @@ def scan(target, dependencies, output, severity, confidence): cve_findings = [] if dependencies: click.echo(f"Scanning dependencies in 
{dependencies}...") - cve_findings = cve_scanner.scan_requirements(dependencies) + cve_findings = cve_scanner.scan_dependencies(dependencies) click.echo(f"Found {len(cve_findings)} vulnerable dependencies") # Create report From 2bc868663429ef5f82dfaf9ca55153b27bc9fca0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 30 Oct 2025 17:19:35 -0400 Subject: [PATCH 3/3] test fix --- securefix/cve/scanner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/securefix/cve/scanner.py b/securefix/cve/scanner.py index 1e6b446..d63958b 100644 --- a/securefix/cve/scanner.py +++ b/securefix/cve/scanner.py @@ -118,19 +118,17 @@ def _extract_version_from_specifier(specifier_set): 2. Minimum version (>=, >) 3. Maximum version (<, <=) """ + # First pass: look for exact version (highest priority) for spec in specifier_set: - operator = spec.operator - version = spec.version - - # Exact match - use this version - if operator == "==": - return version + if spec.operator == "==": + return spec.version - # Greater than or equal - use minimum version - if operator in (">=", ">"): - return version + # Second pass: look for minimum version + for spec in specifier_set: + if spec.operator in (">=", ">"): + return spec.version - # If no >= or ==, look for < or <= + # Third pass: look for maximum version for spec in specifier_set: if spec.operator in ("<", "<="): return spec.version