Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions gitshield/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""GitShield CLI — Prevent accidental secret commits."""

import dataclasses
import re
import sys
from pathlib import Path
Expand All @@ -9,7 +10,8 @@
from . import __version__
from .config import build_custom_patterns, filter_findings, load_config, load_ignore_list, find_git_root
from .formatter import print_findings, print_json, print_blocked_message, colorize, Colors
from .scanner import scan_path, ScannerError
from .models import ScannerError
from .scanner import scan_path


@click.group()
Expand Down Expand Up @@ -49,7 +51,16 @@ def scan(path: str, staged: bool, no_git: bool, as_json: bool, sarif: bool, quie
# Output
if sarif:
from .formatter import print_sarif
print_sarif(findings)
# SARIF requires relative URIs so GitHub Code Scanning can map findings.
scan_root = Path(path).resolve()
sarif_findings = []
for f in findings:
try:
rel = str(Path(f.file).relative_to(scan_root))
except ValueError:
rel = f.file
sarif_findings.append(dataclasses.replace(f, file=rel))
print_sarif(sarif_findings)
elif as_json:
print_json(findings)
else:
Expand Down
53 changes: 48 additions & 5 deletions gitshield/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class GitShieldConfig:
entropy_threshold: float = 4.5
scan_tests: bool = False
allowlist_paths: List[str] = field(default_factory=list)
allowlist_rules: List[str] = field(default_factory=list)
allowlist_rules: Set[str] = field(default_factory=set)
allowlist_fingerprints: Set[str] = field(default_factory=set)
custom_patterns: List[Dict[str, Any]] = field(default_factory=list)

Expand Down Expand Up @@ -122,7 +122,7 @@ def load_ignore_list(path: Path) -> Set[str]:
return set()

ignores = set()
with open(ignore_file) as f:
with open(ignore_file, encoding="utf-8", errors="replace") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
Expand Down Expand Up @@ -173,11 +173,16 @@ def load_config(path: Path) -> GitShieldConfig:
else:
fingerprints = set()

try:
entropy_threshold = float(scan.get("entropy_threshold", 4.5))
except (ValueError, TypeError):
entropy_threshold = 4.5

return GitShieldConfig(
entropy_threshold=float(scan.get("entropy_threshold", 4.5)),
entropy_threshold=entropy_threshold,
scan_tests=bool(scan.get("scan_tests", False)),
allowlist_paths=list(allowlist.get("paths", [])),
allowlist_rules=list(allowlist.get("rules", [])),
allowlist_rules=set(allowlist.get("rules", [])),
allowlist_fingerprints=fingerprints,
custom_patterns=list(data.get("custom_patterns", [])),
)
Expand Down Expand Up @@ -270,6 +275,40 @@ def filter_findings(
# Custom pattern builder
# ---------------------------------------------------------------------------

_REDOS_TEST_STRING = "a" * 100


def _regex_is_safe(compiled_re, timeout: float = 1.0) -> bool:
    """Return True if *compiled_re* finishes searching the probe string in time.

    Guards against catastrophic backtracking (ReDoS) in custom patterns from
    .gitshield.toml.  The search runs exactly once, inside a daemon thread, so
    the caller never sits on an unguarded ``search()`` call and a pattern is
    never evaluated twice.

    Args:
        compiled_re: Object exposing ``search(str)`` (normally a compiled
            ``re`` pattern).
        timeout: Maximum number of seconds to wait for the search to finish.

    Returns:
        True if the search completed within *timeout* seconds, False if it was
        still running when the timeout expired.
    """
    import threading

    finished = threading.Event()

    def _run() -> None:
        compiled_re.search(_REDOS_TEST_STRING)
        finished.set()

    # Daemon thread: a pattern that never finishes must not block interpreter
    # exit.  The thread cannot be cancelled, so it is simply abandoned.
    # NOTE(review): CPython's ``re`` engine is believed to hold the GIL while
    # matching, in which case this wait can overrun *timeout* for a truly
    # catastrophic pattern -- confirm; a subprocess would give a hard bound.
    threading.Thread(target=_run, daemon=True).start()
    return finished.wait(timeout=timeout)


def build_custom_patterns(config: "GitShieldConfig") -> List[Pattern]:
"""Convert config.custom_patterns dicts into Pattern objects.

Expand Down Expand Up @@ -303,6 +342,10 @@ def build_custom_patterns(config: "GitShieldConfig") -> List[Pattern]:
print(f"gitshield: custom pattern '{pattern_id}' has invalid regex: {exc}", file=sys.stderr)
continue

if not _regex_is_safe(compiled):
print(f"gitshield: custom pattern '{pattern_id}' timed out (possible ReDoS), skipping", file=sys.stderr)
continue

try:
built.append(Pattern(
id=pattern_id,
Expand Down Expand Up @@ -336,7 +379,7 @@ def create_ignore_file(path: Path, findings: List[Finding]) -> Path:
lines.append(f.fingerprint)
lines.append("")

with open(ignore_file, "w") as file:
with open(ignore_file, "w", encoding="utf-8") as file:
file.write("\n".join(lines))

return ignore_file
52 changes: 42 additions & 10 deletions gitshield/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""SQLite database for tracking scanned repos and notifications."""

import atexit
import os
import sqlite3
import threading
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Set
Expand All @@ -12,6 +14,7 @@

# Module-level singleton connection — initialized on first use.
_conn: Optional[sqlite3.Connection] = None
_lock = threading.Lock()


def _close_connection() -> None:
Expand All @@ -29,10 +32,14 @@ def get_connection() -> sqlite3.Connection:
"""Return the module-level DB connection, creating it on first call."""
global _conn
if _conn is None:
DB_DIR.mkdir(parents=True, exist_ok=True)
_conn = sqlite3.connect(DB_PATH)
_conn.row_factory = sqlite3.Row
_init_tables(_conn)
with _lock:
if _conn is None:
DB_DIR.mkdir(parents=True, exist_ok=True)
DB_DIR.chmod(0o700)
_conn = sqlite3.connect(DB_PATH, check_same_thread=False)
os.chmod(DB_PATH, 0o600)
_conn.row_factory = sqlite3.Row
_init_tables(_conn)
return _conn


Expand Down Expand Up @@ -116,17 +123,42 @@ def mark_notified(
conn.commit()


def mark_notified_batch(
    repo_url: str,
    fingerprints: List[str],
    email: Optional[str] = None,
    method: str = "email",
) -> None:
    """Record that we notified about multiple findings in a single transaction.

    Args:
        repo_url: Repository the findings belong to.
        fingerprints: Finding fingerprints to record; an empty list is a no-op
            and does not open the database.
        email: Address the notification was sent to, if any.
        method: Notification channel label stored with each row.

    All rows share one ``notified_at`` timestamp taken when the call starts.
    """
    if not fingerprints:
        return
    conn = get_connection()
    now = datetime.now().isoformat()
    rows = [(repo_url, email, fp, now, method) for fp in fingerprints]
    # The connection context manager commits on success and rolls back if
    # executemany raises, so a failed batch cannot leave pending rows behind
    # for a later, unrelated commit() to persist.
    with conn:
        # INSERT OR IGNORE skips rows that conflict with an existing constraint
        # (presumably a uniqueness constraint on notifications -- verify schema).
        conn.executemany("""
            INSERT OR IGNORE INTO notifications
            (repo_url, email, fingerprint, notified_at, method)
            VALUES (?, ?, ?, ?, ?)
        """, rows)


def get_notified_fingerprints(repo_url: str, fingerprints: List[str]) -> Set[str]:
    """Return the subset of *fingerprints* that have already been notified.

    Args:
        repo_url: Repository whose notification rows are consulted.
        fingerprints: Candidate fingerprints; an empty list returns an empty
            set without opening the database.

    Returns:
        The fingerprints from the input that have a row in ``notifications``
        for *repo_url*.
    """
    if not fingerprints:
        return set()
    conn = get_connection()
    # Batch into chunks of 500 to stay well under SQLite's
    # SQLITE_MAX_VARIABLE_NUMBER limit; a single IN (...) list covering every
    # fingerprint could exceed it on large inputs.
    chunk_size = 500
    result: Set[str] = set()
    for start in range(0, len(fingerprints), chunk_size):
        chunk = fingerprints[start:start + chunk_size]
        # Only "?" placeholders are interpolated into the SQL text; the values
        # themselves are always bound as parameters.
        placeholders = ",".join("?" * len(chunk))
        cursor = conn.execute(
            f"SELECT fingerprint FROM notifications WHERE repo_url = ? AND fingerprint IN ({placeholders})",
            (repo_url, *chunk),
        )
        result.update(row["fingerprint"] for row in cursor.fetchall())
    return result


def get_stats() -> dict:
Expand Down
86 changes: 67 additions & 19 deletions gitshield/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import fnmatch
import itertools
import os
import re
import subprocess
Expand Down Expand Up @@ -57,14 +58,9 @@
# ---------------------------------------------------------------------------


def _is_binary_file(filepath: Path) -> bool:
    """Return True if *filepath* looks like a binary file.

    Only the first 8 KB are inspected: the file counts as binary when that
    prefix contains a null byte.  A file that cannot be opened or read is
    also reported as binary.
    """
    try:
        with open(filepath, "rb") as fh:
            chunk = fh.read(8192)
    except OSError:  # IOError is an alias of OSError in Python 3
        return True  # unreadable — treat as binary
    return b"\x00" in chunk
def _has_binary_extension(path: Path) -> bool:
    """Report whether the lower-cased suffix of *path* is in ``_BINARY_EXTENSIONS``."""
    suffix = path.suffix.lower()
    return suffix in _BINARY_EXTENSIONS


def _should_skip_path(path: Path) -> bool:
Expand All @@ -74,9 +70,7 @@ def _should_skip_path(path: Path) -> bool:
if part in _SKIP_DIRS:
return True
# Check binary extensions.
if path.suffix.lower() in _BINARY_EXTENSIONS:
return True
return False
return _has_binary_extension(path)


def _parse_gitignore(root: Path) -> List[str]:
Expand Down Expand Up @@ -113,16 +107,19 @@ def _compile_gitignore_patterns(patterns: List[str]) -> List[tuple]:

def _matches_gitignore(rel_path: str, ignore_patterns: List[tuple]) -> bool:
    """Return True if *rel_path* matches any pre-compiled gitignore pattern.

    Args:
        rel_path: Path to test, as a string.
        ignore_patterns: ``(is_dir, compiled_re)`` pairs.  For a directory-only
            pattern the regex must fully match one path component; otherwise
            it must fully match either the whole path or the basename.

    Returns:
        True on the first pattern that matches, False if none do (including
        when *ignore_patterns* is empty).
    """
    # Split the path once; these values do not change between patterns.
    path_obj = Path(rel_path)
    parts = path_obj.parts
    name = path_obj.name
    for is_dir, compiled_re in ignore_patterns:
        if is_dir:
            # Directory-only pattern: match against path components.
            if any(compiled_re.fullmatch(part) for part in parts):
                return True
        # Match against full relative path and also the basename.
        elif compiled_re.fullmatch(rel_path) or compiled_re.fullmatch(name):
            return True
    return False

Expand Down Expand Up @@ -155,7 +152,7 @@ def scan_text(
"""
findings: List[Finding] = []
lines = text.splitlines()
all_patterns = list(PATTERNS) + list(extra_patterns or [])
all_patterns = PATTERNS if not extra_patterns else itertools.chain(PATTERNS, extra_patterns)

for idx, line in enumerate(lines, start=1):
# Honour inline ignore directives.
Expand Down Expand Up @@ -297,7 +294,12 @@ def scan_directory(

# ---- staged-only mode: delegate to git for the file list ----
if staged_only:
return _scan_staged(root)
return _scan_staged(
root,
config_threshold=config_threshold,
extra_patterns=extra_patterns,
scan_tests=scan_tests,
)

# ---- full tree walk ----
ignore_patterns: List[tuple] = []
Expand All @@ -307,14 +309,23 @@ def scan_directory(

findings: List[Finding] = []

for dirpath, dirnames, filenames in os.walk(root):
for dirpath, dirnames, filenames in os.walk(root, followlinks=False):
# Prune skip directories in-place to prevent descending into them.
dirnames[:] = [d for d in dirnames if d not in _SKIP_DIRS]

for filename in filenames:
file_path = Path(dirpath) / filename

if _should_skip_path(file_path):
# Skip symlinks — they may point outside the repository root.
if file_path.is_symlink():
continue

# Ensure the file is within the root (guards against unusual filesystems).
if not file_path.resolve().is_relative_to(root):
continue

# Directories are already pruned above; only check binary extension here.
if _has_binary_extension(file_path):
continue

# Skip test files when scan_tests is disabled.
Expand Down Expand Up @@ -372,7 +383,12 @@ def scan_content(
# Internal: staged-file scanning
# ---------------------------------------------------------------------------

def _scan_staged(root: Path) -> List[Finding]:
def _scan_staged(
root: Path,
config_threshold: Optional[float] = None,
extra_patterns: Optional[List] = None,
scan_tests: bool = True,
) -> List[Finding]:
"""Scan only files staged in git inside *root*."""
try:
result = subprocess.run(
Expand Down Expand Up @@ -400,6 +416,38 @@ def _scan_staged(root: Path) -> List[Finding]:
continue
if _should_skip_path(file_path):
continue
findings.extend(scan_file(file_path))
if not scan_tests and _is_test_file(file_path.name):
continue

# Read the staged (index) version, not the working-tree copy.
# This prevents bypass: stage secret → edit working tree to remove it.
try:
show_result = subprocess.run(
["git", "show", f":{rel_name}"],
capture_output=True,
cwd=str(root),
timeout=30,
)
except subprocess.TimeoutExpired:
continue
except (OSError, FileNotFoundError):
continue

if show_result.returncode != 0:
continue

try:
staged_content = show_result.stdout.decode("utf-8", errors="replace")
except (UnicodeDecodeError, AttributeError):
continue

findings.extend(
scan_text(
staged_content,
filename=rel_name,
config_threshold=config_threshold,
extra_patterns=extra_patterns,
)
)

return findings
7 changes: 6 additions & 1 deletion gitshield/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def print_findings(findings: List[Finding], quiet: bool = False) -> None:


def format_findings_json(findings: List[Finding]) -> str:
"""Return findings as a JSON string."""
"""Return findings as a JSON string.

Note: the ``file`` field reflects whatever path was passed to the scanner.
When scanning with absolute paths the output will contain absolute paths.
Use the CLI's SARIF output (``--sarif``) for paths normalised to repo root.
"""
data = [
{
"file": f.file,
Expand Down
Loading
Loading