Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pia/dependencytrack.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""DependencyTrack API client."""

import logging

import requests

from .models import DependencyTrackUploadPayload

logger = logging.getLogger(__name__)


class DependencyTrackError(Exception):
"""Raised when DependencyTrack API request fails."""
Expand All @@ -23,11 +27,13 @@ def upload_sbom(
}

try:
logger.info(f"Uploading SBOM to DependencyTrack at {url}")
response = requests.put(
url,
json=payload.to_dict(),
headers=headers,
)
logger.info(f"DependencyTrack responded with status {response.status_code}")
return response

except requests.RequestException as e:
Expand Down
25 changes: 20 additions & 5 deletions pia/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
@asynccontextmanager
async def load_project_settings_on_startup(app: FastAPI):
app.state.projects = Projects.from_yaml_file(settings.projects_path)
logger.info(f"Loaded projects from {settings.projects_path}")
yield


Expand Down Expand Up @@ -76,11 +75,15 @@ async def upload_sbom(
# Handle Auth Header
projects: Projects = request.app.state.projects

logger.info("Received SBOM upload request")

# Extract Bearer token from Authorization header
if not authorization.startswith("Bearer "):
_401("Invalid Authorization header format")
token = authorization[7:] # Remove "Bearer " prefix

logger.info("Bearer token extracted from Authorization header")

# Extract issuer from unverified token
try:
unverified_claims = jwt.decode(
Expand All @@ -92,16 +95,25 @@ async def upload_sbom(
logger.warning(f"Token decode failed: {e}")
_401("Invalid token")

logger.info(f"Unverified issuer extracted: {unverified_issuer}")

# Check if issuer exists in any project configuration to fail early
#
# NOTE: This is an expensive operation (iterates over all projects) for a
# completely unauthenticated request. Consider to ...
# - make less expensive (optimize with db), or
# - match against issuer constants (full for GitHub, prefix-only for Jenkins)
logger.info(
f"Checking if issuer '{unverified_issuer}' is allowed "
f"across {len(projects.root)} project(s)"
)
if not projects.has_issuer(unverified_issuer):
logger.warning(f"Issuer {unverified_issuer} not allowed")
_401("Issuer not allowed")

logger.info(
f"Issuer '{unverified_issuer}' is allowed, proceeding with token verification"
)
# Full token verification
try:
verified_claims = oidc.verify_token(
Expand All @@ -113,6 +125,8 @@ async def upload_sbom(
logger.warning(f"Token verification failed: {e}")
_401("Token verification failed")

logger.info("Token signature verified successfully")

# Find project by matching verified token claims
# NOTE: Returns first match
project = projects.find_project_by_claims(verified_claims)
Expand All @@ -128,6 +142,11 @@ async def upload_sbom(
# ==========================================================================
# Handle Payload

logger.info(
f"Preparing DependencyTrack payload for {payload.product_name} "
f"{payload.product_version}"
)

# Create DependencyTrack payload
dt_payload = DependencyTrackUploadPayload(
project_name=payload.product_name,
Expand All @@ -144,10 +163,6 @@ async def upload_sbom(
settings.dependency_track_api_key,
dt_payload,
)
logger.info(
f"Uploaded SBOM for {project.project_id}/{payload.product_name} "
f"to DependencyTrack (status: {dt_response.status_code})"
)

# Relay DependencyTrack response
response = Response(
Expand Down
31 changes: 27 additions & 4 deletions pia/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Data models with validation and authentication logic."""

import logging
from typing import Annotated, Any

import yaml
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, RootModel, UrlConstraints

logger = logging.getLogger(__name__)

# `preserve_empty_path=True` tells pydantic to not add any trailing slashes,
# to avoid surprising results in `Project.match_issuer`.
HttpsUrl = Annotated[
Expand Down Expand Up @@ -61,7 +64,8 @@ class Projects(RootModel):

def has_issuer(self, issuer: str) -> bool:
"""Check if any project has the given issuer."""
return any(project.match_issuer(issuer) for project in self.root)
found = any(project.match_issuer(issuer) for project in self.root)
return found

def find_project_by_claims(self, token_claims: dict[str, Any]) -> Project | None:
"""Find project by matching verified token claims.
Expand All @@ -70,9 +74,26 @@ def find_project_by_claims(self, token_claims: dict[str, Any]) -> Project | None
A project matches if issuer matches AND all required_claims match.
"""
issuer = token_claims["iss"]
logger.info(
f"Searching for project matching issuer '{issuer}' and token claims"
)
for project in self.root:
if project.match_issuer(issuer) and project.match_claims(token_claims):
return project
if not project.match_issuer(issuer):
logger.info(
f"Project '{project.project_id}': issuer mismatch, skipping"
)
continue
if not project.match_claims(token_claims):
logger.info(
f"Project '{project.project_id}': issuer matches"
" but claims mismatch, skipping"
)
continue
logger.info(
f"Project '{project.project_id}': issuer and claims match, "
f"claims are: {project.required_claims}"
)
return project
return None

@classmethod
Expand All @@ -81,7 +102,9 @@ def from_yaml_file(cls, path: str) -> "Projects":
with open(path) as f:
config = yaml.safe_load(f)

return cls.model_validate(config)
projects = cls.model_validate(config)
logger.info(f"Loaded {len(projects.root)} project(s) from {path}")
return projects


class PiaUploadPayload(BaseModel):
Expand Down
16 changes: 16 additions & 0 deletions pia/oidc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""OIDC token validation and signature verification."""

import logging
from typing import Any

import jwt
import requests

logger = logging.getLogger(__name__)


class TokenVerificationError(Exception):
"""Raised when token verification fails."""
Expand All @@ -18,13 +21,17 @@ def verify_token(
"""Verify JWT token signature using OIDC discovery and return claims.
Raises TokenVerificationError, if verification fails
"""
logger.info(f"Starting token verification for issuer: {issuer}")

# 1. Request OIDC configuration from issuer
config_url = f"{issuer}/.well-known/openid-configuration"

try:
logger.info(f"Fetching OIDC configuration from {config_url}")
response = requests.get(config_url, timeout=10)
response.raise_for_status()
oidc_config = response.json()
logger.info("OIDC configuration fetched successfully")

except requests.RequestException as e:
raise TokenVerificationError(
Expand All @@ -36,12 +43,20 @@ def verify_token(
if not jwks_uri:
raise TokenVerificationError("OIDC configuration missing 'jwks_uri'")

logger.info(f"JWKS URI: {jwks_uri}")

try:
# 3. Requests public keys from issuer
logger.info("Fetching signing key from JWKS endpoint")
jwks_client = jwt.PyJWKClient(jwks_uri)
signing_key = jwks_client.get_signing_key_from_jwt(token)
logger.info("Signing key retrieved successfully")

# 4. Verify token signature and content
logger.info(
f"Verifying token signature and claims "
f"(expected audience: {expected_audience})"
)
claims = jwt.decode(
token,
signing_key.key,
Expand All @@ -56,6 +71,7 @@ def verify_token(
),
)

logger.info("Token decoded and verified successfully")
return claims

except Exception as e:
Expand Down
8 changes: 5 additions & 3 deletions tests/test_dependencytrack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for dependencytrack module."""

from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
import requests
Expand Down Expand Up @@ -28,11 +28,13 @@ class TestUploadSBOM:
@patch("pia.dependencytrack.requests.put")
def test_upload(self, mock_put, dt_payload):
"""Test request and response."""
mock_put.return_value = "mock_response"
mock_response = MagicMock()
mock_response.status_code = 200
mock_put.return_value = mock_response
result = upload_sbom(TEST_URL, TEST_API_KEY, dt_payload)

# Assert result is request response
assert result == "mock_response"
assert result == mock_response

# Assert request was made correctly
mock_put.assert_called_once_with(
Expand Down
Loading