Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/extension_shield/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@
# Initialize logger
logger = logging.getLogger(__name__)


def _parse_trusted_proxy_hosts() -> list[str]:
"""Return the explicit proxy hosts allowed to send forwarded headers."""
raw_hosts = os.getenv("TRUSTED_PROXY_HOSTS", "").strip()
if raw_hosts:
return [host.strip() for host in raw_hosts.split(",") if host.strip()]
return ["127.0.0.1", "localhost", "::1"]
Comment on lines +66 to +71
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_parse_trusted_proxy_hosts() currently returns raw comma-separated tokens without validating format. Since ProxyHeadersMiddleware.trusted_hosts expects a very specific set of values (e.g., IPs/CIDRs/* depending on Uvicorn), a typo or hostname here can silently misconfigure proxy trust (or fail at runtime). Consider validating each entry on startup and raising a clear ValueError (or logging an explicit warning) when an entry is not a valid trusted-host spec, and document the expected format in the docstring/env var name.

Copilot uses AI. Check for mistakes.

# Import safe JSON utilities from shared module
from extension_shield.utils.json_encoder import (
safe_json_dumps,
Expand Down Expand Up @@ -361,8 +369,9 @@ async def add_security_headers(request: Request, call_next):
print(f"✅ CSP: Production mode detected (STATIC_DIR={STATIC_DIR}, index.html exists)")
app.add_middleware(CSPMiddleware, is_dev=_is_dev)

# Trust X-Forwarded-Proto / X-Forwarded-For from Railway/Cloudflare so request.url.scheme is correct
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*")
# Trust forwarded headers only from explicitly allowed proxy hosts.
# Set TRUSTED_PROXY_HOSTS to your actual reverse proxy / CDN hop(s).
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts=_parse_trusted_proxy_hosts())
Comment on lines +372 to +374
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Security logic still directly trusts the raw X-Forwarded-Proto header in add_security_headers (it checks the header when request.url.scheme is not https). Even with ProxyHeadersMiddleware restricted, an untrusted client can still send X-Forwarded-Proto: https and trigger HSTS / “effective HTTPS” behavior. To fully eliminate spoofing, remove the direct header fallback and rely on the scheme as rewritten by ProxyHeadersMiddleware (or only consult forwarded headers when the immediate peer is trusted).

Copilot uses AI. Check for mistakes.
Comment on lines +372 to +374
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are existing API tests (e.g., CORS + admin endpoint tests) but nothing that exercises the new proxy-trust behavior. Please add tests that prove untrusted requests cannot influence request.url.scheme / client IP via X-Forwarded-*, and that trusted proxy hops can. This will help prevent regressions where forwarded headers are accidentally re-trusted broadly again.

Copilot uses AI. Check for mistakes.

# In-memory state lives in shared.py; import references here so existing
# code in this file (and tests) can continue using module-level names.
Expand Down Expand Up @@ -408,20 +417,9 @@ def _get_client_ip(request: Request) -> str:
"""
Get the client's IP address for rate limiting anonymous users.

Handles proxied requests via X-Forwarded-For and X-Real-IP headers.
Falls back to client host if no headers present.
Relies on ProxyHeadersMiddleware to rewrite request.client only when the
request came from a trusted proxy host.
"""
# Check X-Forwarded-For header (from reverse proxy/load balancer)
x_forwarded_for = request.headers.get("x-forwarded-for")
if x_forwarded_for:
# Take the first IP (original client)
return x_forwarded_for.split(",")[0].strip()

# Check X-Real-IP header (from nginx)
x_real_ip = request.headers.get("x-real-ip")
if x_real_ip:
return x_real_ip.strip()

# Fall back to direct client IP
if request.client:
return request.client.host
Expand Down
Loading