Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .coverage
Binary file not shown.
9 changes: 4 additions & 5 deletions .eslintrc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
],
"extends": [
"eslint:recommended",
"plugin:@typescript-eslint/recommended",
"plugin:react-hooks/recommended"
"plugin:@typescript-eslint/recommended"
],
"rules": {
"eqeqeq": [
Expand Down Expand Up @@ -53,9 +52,6 @@
],
"@typescript-eslint/no-non-null-assertion": "off",
"@typescript-eslint/no-explicit-any": "off",
"react-hooks/rules-of-hooks": "error",
"react-hooks/exhaustive-deps": "error",
"react/prop-types": "off",
"spellcheck/spell-checker": [
"error",
{
Expand Down Expand Up @@ -216,6 +212,7 @@
"Inservice",
"instanceof",
"interactable",
"invoker",
"ipinsights",
"iss",
"iter",
Expand Down Expand Up @@ -250,6 +247,7 @@
"mlogloss",
"mlp",
"mls",
"minify",
"mlspace",
"msd",
"mse",
Expand All @@ -260,6 +258,7 @@
"ndcg",
"nlists",
"ngrams",
"nodejs",
"nms",
"nonpositive",
"nthread",
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/code.deploy.demo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ jobs:
working-directory: ./frontend
run: |
npm install
- name: Build frontend
working-directory: ./frontend
run: |
npm run build
- name: Install CDK dependencies
run: |
npm install
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/code.deploy.development.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ jobs:
working-directory: ./frontend
run: |
npm install
- name: Build frontend
working-directory: ./frontend
run: |
npm run build
- name: Install CDK dependencies
run: |
npm install
Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/code.test-and-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ jobs:
- name: Install dependencies
run: |
npm install
- name: Install Cypress dependencies
working-directory: ./cypress
run: |
npm install
- name: Lint
working-directory: ./cypress
run: |
Expand All @@ -131,6 +135,10 @@ jobs:
working-directory: ./
run: |
npm install
- name: Install dependencies Cypress
working-directory: ./cypress
run: |
npm install
- uses: pre-commit/action@v3.0.1
send_final_slack_notification:
name: Send Final Slack Notification
Expand Down
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ npm-debug.log*
.attach_pid*
cdk.context.json
.idea
.cursor
memory-bank

# Cypress Tests
cypress/cypress/screenshots/

# Environment-specific config
lib/config.json
8 changes: 6 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ repos:
- id: flake8
args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics", "--max-line-length=127"]

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.56.0
# Run package-local ESLint via scripts/precommit-eslint.sh: root (lib/), frontend/, cypress/.
- repo: local
hooks:
- id: eslint
name: eslint
language: system
entry: scripts/precommit-eslint.sh
files: \.[jt]sx?$ # *.js, *.jsx, *.ts and *.tsx
exclude: ^frontend/docs/
types: [file]

- repo: https://github.com/Lucas-C/pre-commit-hooks
Expand Down
48 changes: 43 additions & 5 deletions backend/src/ml_space_lambda/auth/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _get_auth_config() -> Dict[str, str]:
"token_encryption_key_secret_name": os.environ.get("AUTH_TOKEN_ENCRYPTION_KEY_SECRET_NAME", ""),
"session_table_name": os.environ.get("AUTH_SESSION_TABLE_NAME", ""),
"sync_domains": os.environ.get("AUTH_SYNC_DOMAINS", ""),
"ALLOW_LOCALHOST": os.environ.get("ALLOW_LOCALHOST", "false").lower() == "true",
}

# Validate IdP type first
Expand All @@ -112,6 +113,9 @@ def _get_auth_config() -> Dict[str, str]:
if not config["session_table_name"]:
raise Exception("AUTH_SESSION_TABLE_NAME environment variable is required")

if config["ALLOW_LOCALHOST"]:
logger.warning("ALLOW_LOCALHOST enabled; please disable for production deployments")

return config


Expand Down Expand Up @@ -433,7 +437,7 @@ def _get_auth_path(event: Dict) -> str:
return "/auth"


def _validate_redirect_url(redirect_url: str, host_header: str) -> bool:
def _validate_redirect_url(redirect_url: str, host_header: str, allow_localhost: bool = False) -> bool:
"""
Validate that redirect URL is safe and belongs to the same origin.

Expand All @@ -447,6 +451,10 @@ def _validate_redirect_url(redirect_url: str, host_header: str) -> bool:
if not redirect_url:
return False

if allow_localhost:
if redirect_url == "http://localhost:3000" or redirect_url == "http://localhost:3000/Prod":
return True

try:
parsed = urlparse(redirect_url)

Expand Down Expand Up @@ -490,7 +498,7 @@ def login(event, context):
host_header = event.get("headers", {}).get("Host") or event.get("headers", {}).get("host", "")

# Validate redirect URL
if not _validate_redirect_url(redirect_url, host_header):
if not _validate_redirect_url(redirect_url, host_header, config["ALLOW_LOCALHOST"]):
logger.warning(f"Invalid redirect URL: {redirect_url}")
redirect_url = root_path

Expand Down Expand Up @@ -840,9 +848,19 @@ def callback(event, context):
secure_flag = should_set_secure_flag(host_header)
domain = extract_domain_from_host(host_header)
root_path = _get_root_path(event)
same_site = "Strict"

# For localhost development: use None for cross-site requests
if config["ALLOW_LOCALHOST"]:
same_site = "None" # Allows cross-site requests (localhost -> AWS)

session_cookie = create_session_cookie(
session_id=session_id, max_age_seconds=int(expires_at), domain=domain, secure=secure_flag, path=root_path
session_id=session_id,
max_age_seconds=int(expires_at),
domain=domain,
secure=secure_flag,
path=root_path,
same_site=same_site,
)

# Clear state cookie
Expand Down Expand Up @@ -1222,9 +1240,19 @@ def _attempt_token_refresh(
secure_flag = should_set_secure_flag(host_header)
domain = extract_domain_from_host(host_header)
root_path = _get_root_path(event)
same_site = "Strict"

# For localhost development: use None for cross-site requests
if config.get("ALLOW_LOCALHOST"):
same_site = "None" # Allows cross-site requests (localhost -> AWS)

new_session_cookie = create_session_cookie(
session_id=session_id, max_age_seconds=int(access_expires), domain=domain, secure=secure_flag, path=root_path
session_id=session_id,
max_age_seconds=int(access_expires),
domain=domain,
secure=secure_flag,
path=root_path,
same_site=same_site,
)

logger.info(f"Token refresh successful for session: {session_id}")
Expand Down Expand Up @@ -1457,12 +1485,22 @@ def _set_session_cookie_for_domain(session_id, event, config) -> str:
secure_flag = should_set_secure_flag(host_header)
domain = extract_domain_from_host(host_header)
root_path = _get_root_path(event)
same_site = "Strict"

# For localhost development: use None for cross-site requests
if config.get("ALLOW_LOCALHOST"):
same_site = "None" # Allows cross-site requests (localhost -> AWS)

# Use default session TTL (24 hours) for sync cookies
session_ttl = int(config.get("session_ttl_hours", "24")) * 3600

return create_session_cookie(
session_id=session_id, max_age_seconds=session_ttl, domain=domain, secure=secure_flag, path=root_path
session_id=session_id,
max_age_seconds=session_ttl,
domain=domain,
secure=secure_flag,
path=root_path,
same_site=same_site,
)


Expand Down
50 changes: 50 additions & 0 deletions backend/src/ml_space_lambda/auth/utils/key_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from typing import Dict, Optional

import boto3
from botocore.exceptions import ClientError
from pydantic import ValidationError

from ml_space_lambda.auth.models.key_models import (
KeyRotationResult,
Expand Down Expand Up @@ -56,6 +58,34 @@ def _get_default_keep_versions() -> int:
return 3


def _secret_already_versioned_for_type(secret_id: str, expected: KeyType, secrets_client) -> bool:
"""
Return True if the secret string parses as VersionedKeyData with matching key_type and keys.

AWS API and permission errors propagate so deploy does not overwrite a valid secret after a
transient failure. Only empty, non-JSON, or structurally invalid payloads are treated as
not-yet-versioned.
"""
try:
response = secrets_client.get_secret_value(SecretId=secret_id)
except ClientError:
raise

raw = response.get("SecretString") or ""
if not raw.strip():
return False
try:
data = VersionedKeyData.from_secrets_manager_format(raw)
except (json.JSONDecodeError, ValidationError, ValueError, TypeError) as e:
logger.info(
"Secret %s not in expected versioned JSON format; will initialize if needed: %s",
secret_id,
e,
)
return False
return data.key_type == expected and bool(data.keys)


def initialize_state_encryption_key(secret_arn: str) -> Dict:
"""
Initialize state encryption key secret with versioned structure.
Expand All @@ -72,6 +102,14 @@ def initialize_state_encryption_key(secret_arn: str) -> Dict:
try:
secrets_client = boto3.client("secretsmanager")

if _secret_already_versioned_for_type(secret_arn, KeyType.STATE, secrets_client):
logger.info("State encryption secret already in versioned JSON format; skipping initialization.")
return {
"success": True,
"skipped": True,
"key_type": KeyType.STATE,
}

# Generate initial Fernet key
initial_key = create_state_encryption_key()
encoded_key = encode_state_key_for_storage(initial_key)
Expand All @@ -93,6 +131,8 @@ def initialize_state_encryption_key(secret_arn: str) -> Dict:
"created_date": key_data.created_date.isoformat(),
}

except ClientError:
raise
except Exception as e:
logger.error(f"State key initialization failed: {e}")
raise Exception(f"State key initialization failed: {e}")
Expand All @@ -114,6 +154,14 @@ def initialize_token_encryption_key(secret_arn: str) -> Dict:
try:
secrets_client = boto3.client("secretsmanager")

if _secret_already_versioned_for_type(secret_arn, KeyType.TOKEN, secrets_client):
logger.info("Token encryption secret already in versioned JSON format; skipping initialization.")
return {
"success": True,
"skipped": True,
"key_type": KeyType.TOKEN,
}

# Generate initial PASETO key
initial_key = create_encryption_key()
encoded_key = encode_key_for_storage(initial_key)
Expand All @@ -135,6 +183,8 @@ def initialize_token_encryption_key(secret_arn: str) -> Dict:
"created_date": key_data.created_date.isoformat(),
}

except ClientError:
raise
except Exception as e:
logger.error(f"Token key initialization failed: {e}")
raise Exception(f"Token key initialization failed: {e}")
Expand Down
22 changes: 22 additions & 0 deletions backend/src/ml_space_lambda/auth/utils/rotation_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,22 @@ def token_key_secrets_manager_rotation_handler(event: Dict[str, Any], context: A
raise


def deploy_time_auth_secrets_init(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""
Deploy-time hook: ensure state and token secrets use VersionedKeyData JSON (not plaintext).

Secret names come from environment variables set by CDK.
"""
state_name = os.environ.get("AUTH_STATE_SECRET_NAME", "").strip()
token_name = os.environ.get("AUTH_TOKEN_SECRET_NAME", "").strip()
results: Dict[str, Any] = {}
if state_name:
results["state"] = initialize_secret_handler({"secret_name": state_name, "key_type": KeyType.STATE}, context)
if token_name:
results["token"] = initialize_secret_handler({"secret_name": token_name, "key_type": KeyType.TOKEN}, context)
return results


def initialize_secret_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""
Handler for initializing secrets with proper key structures.
Expand All @@ -238,6 +254,12 @@ def initialize_secret_handler(event: Dict[str, Any], context: Any) -> Dict[str,
if not secret_name:
raise ValueError("secret_name is required")

if isinstance(key_type, str):
try:
key_type = KeyType(key_type)
except ValueError as e:
raise ValueError(f"Unknown key_type: {key_type}") from e

if key_type == KeyType.STATE:
result = initialize_state_encryption_key(secret_name)
elif key_type == KeyType.TOKEN:
Expand Down
Loading
Loading