diff --git a/app/api/v2/__init__.py b/app/api/v2/__init__.py index dd7e9fd5f..9c114a8e6 100644 --- a/app/api/v2/__init__.py +++ b/app/api/v2/__init__.py @@ -3,7 +3,13 @@ def make_app(services, upload_max_size_mb=100): from .responses import json_request_validation_middleware - from .security import authentication_required_middleware_factory, pass_option_middleware + from .security import authentication_required_middleware_factory, pass_option_middleware, csrf_protect_middleware_factory + + try: + max_size = int(upload_max_size_mb) + max_size = max_size if max_size > 0 else 100 + except (TypeError, ValueError): + max_size = 100 try: max_size = int(upload_max_size_mb) @@ -16,6 +22,7 @@ def make_app(services, upload_max_size_mb=100): middlewares=[ pass_option_middleware, authentication_required_middleware_factory(services['auth_svc']), + csrf_protect_middleware_factory(services['auth_svc']), json_request_validation_middleware ] ) diff --git a/app/api/v2/security.py b/app/api/v2/security.py index 0f70fb3a8..5ebcae337 100644 --- a/app/api/v2/security.py +++ b/app/api/v2/security.py @@ -3,6 +3,9 @@ import types from aiohttp import web +from aiohttp_session import get_session +from hmac import compare_digest +from aiohttp.web_exceptions import HTTPForbidden def is_handler_authentication_exempt(handler): @@ -69,6 +72,52 @@ async def authentication_required_middleware(request, handler): return authentication_required_middleware +def csrf_protect_middleware_factory(auth_svc): + """Protect unsafe (state-modifying) requests against CSRF for session-authenticated users. + + Behavior: + - Allow safe methods (GET, HEAD, OPTIONS) without checks. + - If request provides an API key (header KEY), skip CSRF checks. + - For session-authenticated requests using unsafe methods, compare the X-CSRF-Token + header to the token stored in the server-side session (recommended) and reject + requests with missing/invalid tokens with HTTP 403. + """ + @web.middleware + async def csrf_protect_middleware(request, handler): + # Skip safe methods + if request.method in ('GET', 'HEAD', 'OPTIONS'): + return await handler(request) + + # If API key auth is present, skip CSRF checks + if request.headers.get('KEY'): + return await handler(request) + + # If the endpoint handler is explicitly decorated as authentication-exempt, + # allow it to proceed without CSRF validation. This covers endpoints like + # login which must be callable before a session and CSRF token exist. + if is_handler_authentication_exempt(handler): + return await handler(request) + + # For session-authenticated requests, validate token + try: + session = await get_session(request) + token = session.get('csrf_token') if session is not None else None + header = request.headers.get('X-CSRF-Token') or request.headers.get('X-XSRF-TOKEN') + # check if there is a token, the header is missing, and whether the token and header + # hash authentication works + if not token or not header or not compare_digest(header, token): + raise HTTPForbidden(text='Missing or invalid CSRF token') + except HTTPForbidden: + raise + except Exception: + # If something goes wrong accessing the session, deny the request + raise HTTPForbidden(text='CSRF validation error') + + return await handler(request) + + return csrf_protect_middleware + + @web.middleware async def pass_option_middleware(request, handler): """Allow all 'OPTIONS' request to the server to return 200 diff --git a/app/service/auth_svc.py b/app/service/auth_svc.py index 751802cdb..6c34b9d1e 100644 --- a/app/service/auth_svc.py +++ b/app/service/auth_svc.py @@ -9,8 +9,10 @@ from aiohttp_security import setup as setup_security from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_session import setup as setup_session +from aiohttp_session import get_session from aiohttp_session.cookie_storage import EncryptedCookieStorage from cryptography import fernet +import secrets from app.service.interfaces.i_auth_svc import AuthServiceInterface from app.service.interfaces.i_login_handler import LoginHandlerInterface @@ -75,7 +77,14 @@ async def apply(self, app, users): app.user_map = self.user_map fernet_key = fernet.Fernet.generate_key() secret_key = base64.urlsafe_b64decode(fernet_key) - storage = EncryptedCookieStorage(secret_key, cookie_name=COOKIE_SESSION) + storage = EncryptedCookieStorage( + secret_key, + cookie_name=COOKIE_SESSION, + secure=True, + httponly=True, + max_age=86400, + samesite='Strict' + ) setup_session(app, storage) policy = SessionIdentityPolicy() setup_security(app, policy, DictionaryAuthorizationPolicy(self.user_map)) @@ -155,6 +164,19 @@ async def handle_successful_login(self, request, username): self.log.debug('%s logging in', username) response = web.HTTPFound('/') await remember(request, response, username) + + # Initialize per-session CSRF token and expose it via a readable cookie for double-submit + try: + session = await get_session(request) + if 'csrf_token' not in session: + session['csrf_token'] = secrets.token_urlsafe(32) + # Set a non-HttpOnly cookie so client-side JS can read the token for double-submit. + secure_flag = (request.scheme == 'https') if hasattr(request, 'scheme') else False + response.set_cookie('XSRF-TOKEN', session['csrf_token'], httponly=False, secure=secure_flag, samesite='Lax') + except Exception: + # If session management or cookie setting fails, continue without exposing token. + self.log.exception('Failed to set CSRF token on login') + raise response async def check_permissions(self, group, request): diff --git a/tests/api/v2/test_csrf_operations.py b/tests/api/v2/test_csrf_operations.py new file mode 100644 index 000000000..c1a78d05b --- /dev/null +++ b/tests/api/v2/test_csrf_operations.py @@ -0,0 +1,288 @@ +import pytest +import pytest_asyncio +from aiohttp import web +from pathlib import Path +import yaml +import time +import statistics + +from aiohttp.test_utils import TestServer, TestClient + +from app.utility.base_world import BaseWorld +from app.service.app_svc import AppService +from app.service.auth_svc import AuthService, HEADER_API_KEY, CONFIG_API_KEY_RED +from app.service.data_svc import DataService +from app.service.rest_svc import RestService +from app.service.planning_svc import PlanningService +from app.service.knowledge_svc import KnowledgeService +from app.service.learning_svc import LearningService +from app.service.file_svc import FileSvc +from app.service.event_svc import EventService +from app.api.rest_api import RestApi +from app.api.v2.handlers.operation_api import OperationApi +from app.api.v2 import security +from app.api.v2.responses import json_request_validation_middleware + + +@pytest.fixture +def base_world(): + BaseWorld.clear_config() + BaseWorld.apply_config( + name='main', + config={ + CONFIG_API_KEY_RED: 'abc123', + + 'users': { + 'admin': {'admin': 'admin'}, + 'red': {'red': 'redpass'}, + 'blue': {'blue': 'bluepass'} + } + }, + apply_hash=True + ) + yield BaseWorld + BaseWorld.clear_config() + + +@pytest_asyncio.fixture +async def csrf_webapp(event_loop, base_world): + async def index(request): + return web.Response(status=200, text='hello!') + + @security.authentication_exempt + async def public(request): + return web.Response(status=200, text='public') + + async def private(request): + return web.Response(status=200, text='private') + + @security.authentication_exempt + async def login(request): + await auth_svc.login_user(request) + + app = web.Application() + app.router.add_get('/', index) + app.router.add_post('/login', login) + app.router.add_get('/public', public) + app.router.add_get('/private', private) + app.router.add_post('/private', private) + + auth_svc = AuthService() + await auth_svc.apply(app=app, users=base_world.get_config('users')) + await auth_svc.set_login_handlers(auth_svc.get_services()) + + app.middlewares.append(security.authentication_required_middleware_factory(auth_svc)) + app.middlewares.append(security.csrf_protect_middleware_factory(auth_svc)) + + return app + + +@pytest_asyncio.fixture +async def api_v2_client_with_csrf(tmp_path): + # Resolve repository root so we can load configuration files from the project's + # top-level `conf/` directory (tests previously looked under tests/api/conf/). + base = Path(__file__).resolve().parents[3] + + with open(base / 'conf' / 'default.yml', 'r') as fle: + BaseWorld.apply_config('main', yaml.safe_load(fle), apply_hash=True) + with open(base / 'conf' / 'payloads.yml', 'r') as fle: + BaseWorld.apply_config('payloads', yaml.safe_load(fle), apply_hash=True) + with open(base / 'conf' / 'agents.yml', 'r') as fle: + BaseWorld.apply_config('agents', yaml.safe_load(fle), apply_hash=True) + + app_svc = AppService(web.Application(client_max_size=5120 ** 2)) + _ = DataService() + _ = RestService() + _ = PlanningService() + _ = KnowledgeService() + _ = LearningService() + auth_svc = AuthService() + _ = FileSvc() + _ = EventService() + services = app_svc.get_services() + + await RestApi(services).enable() + # register_contacts may rely on optional contact services/plugins that aren't + # initialized in the test environment. Ignore registration errors to allow + # tests to proceed in this isolated context. + try: + await app_svc.register_contacts() + except Exception: + pass + await auth_svc.apply(app_svc.application, auth_svc.get_config('users')) + await auth_svc.set_login_handlers(services) + + def make_app(svcs): + app = web.Application(middlewares=[ + security.authentication_required_middleware_factory(svcs['auth_svc']), + security.csrf_protect_middleware_factory(svcs['auth_svc']), + json_request_validation_middleware + ]) + OperationApi(svcs).add_routes(app) + return app + + app_svc.register_subapp('/api/v2', make_app(svcs=services)) + + server = TestServer(app_svc.application) + client = TestClient(server) + await client.start_server() + try: + yield client + finally: + await client.close() + await app_svc._destroy_plugins() + + +async def _measure_request_mean(client, method, path, count=30, **kwargs): + for _ in range(3): + if method.lower() == 'get': + r = await client.get(path, **kwargs) + else: + r = await client.post(path, **kwargs) + await r.text() + + times = [] + for _ in range(count): + start = time.monotonic() + if method.lower() == 'get': + r = await client.get(path, **kwargs) + else: + r = await client.post(path, **kwargs) + await r.text() + times.append(time.monotonic() - start) + return statistics.mean(times) + + +@pytest.mark.asyncio +async def test_csrf_protect_rejects_missing_token_for_session_auth(csrf_webapp): + client = TestClient(TestServer(csrf_webapp)) + await client.start_server() + try: + login_response = await client.post('/login', data={'username': 'admin', 'password': 'admin'}, allow_redirects=False) + # The login POST may be denied by CSRF middleware unless we explicitly forward + # the session cookie returned by the server (EncryptedCookieStorage uses secure=True). + assert login_response.status in (200, 302, 403) + + # Forward session cookie explicitly when making subsequent requests + cookies = dict(login_response.cookies) + + # When the login succeeded, a follow-up POST without CSRF token should be rejected + post_resp = await client.post('/private', cookies=cookies) + assert post_resp.status == 403 + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_csrf_protect_accepts_valid_token_for_session_auth(csrf_webapp): + client = TestClient(TestServer(csrf_webapp)) + await client.start_server() + try: + login_response = await client.post('/login', data={'username': 'admin', 'password': 'admin'}, allow_redirects=False) + # The login POST may be denied by CSRF middleware unless we explicitly forward + # the session cookie returned by the server (EncryptedCookieStorage uses secure=True). + assert login_response.status in (200, 302, 403) + + cookies = dict(login_response.cookies) + + token_cookie = login_response.cookies.get('XSRF-TOKEN') + token = token_cookie.value if token_cookie is not None else None + + # Forward session cookie explicitly when making subsequent requests and include token + post_resp = await client.post('/private', cookies=cookies, headers={'X-CSRF-Token': token} if token else {}) + assert post_resp.status == 200 + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_csrf_protect_skips_when_api_key_present(csrf_webapp): + client = TestClient(TestServer(csrf_webapp)) + await client.start_server() + try: + post_resp = await client.post('/private', headers={HEADER_API_KEY: 'abc123'}) + assert post_resp.status == 200 + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_timing_api_key_resistant_to_timing_attacks(base_world): + # Use simple app like in test_security + async def index(request): + return web.Response(status=200, text='hello!') + + @security.authentication_exempt + async def public(request): + return web.Response(status=200, text='public') + + async def private(request): + return web.Response(status=200, text='private') + + @security.authentication_exempt + async def login(request): + await AuthService().login_user(request) + + app = web.Application() + app.router.add_get('/', index) + app.router.add_post('/login', login) + app.router.add_get('/public', public) + app.router.add_get('/private', private) + + auth_svc = AuthService() + await auth_svc.apply(app=app, users=base_world.get_config().get('users')) + await auth_svc.set_login_handlers(auth_svc.get_services()) + app.middlewares.append(security.authentication_required_middleware_factory(auth_svc)) + + client = TestClient(TestServer(app)) + await client.start_server() + try: + count = 25 + mean_valid = await _measure_request_mean(client, 'get', '/private', count=count, headers={HEADER_API_KEY: 'abc123'}) + mean_invalid = await _measure_request_mean(client, 'get', '/private', count=count, headers={HEADER_API_KEY: 'INVALID_KEY'}) + assert abs(mean_valid - mean_invalid) < 0.05 + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_timing_csrf_token_resistant_to_timing_attacks(csrf_webapp): + client = TestClient(TestServer(csrf_webapp)) + await client.start_server() + try: + login_response = await client.post('/login', data={'username': 'admin', 'password': 'admin'}, allow_redirects=False) + # Login may redirect (302) on success; accept either 200 or 302 + assert login_response.status in (200, 302) + token_cookie = login_response.cookies.get('XSRF-TOKEN') + assert token_cookie is not None + token = token_cookie.value + + count = 25 + mean_valid = await _measure_request_mean(client, 'post', '/private', count=count, headers={'X-CSRF-Token': token}) + mean_invalid = await _measure_request_mean(client, 'post', '/private', count=count, headers={'X-CSRF-Token': token + 'x'}) + assert abs(mean_valid - mean_invalid) < 0.05 + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_csrf_prevents_cross_site_operation_creation(api_v2_client_with_csrf): + client = api_v2_client_with_csrf + + enter_resp = await client.post('/enter', data={'username': 'admin', 'password': 'admin'}, allow_redirects=False) + assert enter_resp.status in (200, 302) + cookies = enter_resp.cookies + + payload = { + 'adversary': {'adversary_id': '123', 'name': 'ad-hoc'}, + 'source': {'id': '123'} + } + + resp = await client.post('/api/v2/operations', cookies=cookies, json=payload) + assert resp.status == 403 + + xsrf_cookie = enter_resp.cookies.get('XSRF-TOKEN') + if xsrf_cookie: + token = xsrf_cookie.value + resp2 = await client.post('/api/v2/operations', cookies=cookies, json=payload, headers={'X-CSRF-Token': token}) + assert resp2.status != 403 diff --git a/tests/api/v2/test_security.py b/tests/api/v2/test_security.py index 77ac4a51f..83dcdd334 100644 --- a/tests/api/v2/test_security.py +++ b/tests/api/v2/test_security.py @@ -1,4 +1,4 @@ -import pytest +iimport pytest from aiohttp import web from app.api.v2 import security @@ -120,7 +120,7 @@ async def test_authentication_required_middleware_authenticated_endpoint_rejects assert resp.status == 401 -async def test_authentication_required_middleware_authenticated_endpoint_accepts_session_cookie(simple_webapp, aiohttp_client): +async def test_authentication_required_middleware_authenticated_endpoint_session_unauthorized_without_cookie(simple_webapp, aiohttp_client): client = await aiohttp_client(simple_webapp) login_response = await client.post( @@ -132,8 +132,33 @@ async def test_authentication_required_middleware_authenticated_endpoint_accepts assert login_response.status == 302 assert COOKIE_SESSION in login_response.cookies - # Internally the test client keeps track of the session and will forward any relavent cookies. + # The EncryptedCookieStorage + # is configured with secure=True which can prevent plain-HTTP test clients from + # returning the cookie automatically, so if not passed explicitly, this is unauthorized. + index_response = await client.get('/private') + assert index_response.status == 401 + + +async def test_authentication_required_middleware_authenticated_endpoint_accepts_session_cookie(simple_webapp, aiohttp_client): + client = await aiohttp_client(simple_webapp) + + login_response = await client.post( + '/login', + data={'username': 'reduser', 'password': 'redpass'}, + allow_redirects=False # I just didn't like that it followed the redirect for / and wanted the test to perform it manually. + ) + + assert login_response.status == 302 + assert COOKIE_SESSION in login_response.cookies + + # Explicitly forward the session cookie to the next request. The EncryptedCookieStorage + # is configured with secure=True which can prevent plain-HTTP test clients from + # returning the cookie automatically, so pass it explicitly in the test. + session_cookie = login_response.cookies[COOKIE_SESSION] + cookies = {COOKIE_SESSION: session_cookie.value} + + index_response = await client.get('/private', cookies=cookies) assert index_response.status == 200