diff --git a/src/event_gate_lambda.py b/src/event_gate_lambda.py index 6855a17..92be944 100644 --- a/src/event_gate_lambda.py +++ b/src/event_gate_lambda.py @@ -14,7 +14,10 @@ # limitations under the License. # -"""Event Gate Lambda function implementation.""" +""" +This module contains the AWS Lambda entry point for the EventGate service. +""" + import json import logging import os @@ -24,6 +27,7 @@ import boto3 from botocore.exceptions import BotoCoreError, NoCredentialsError +from src.handlers.handler_api import HandlerApi from src.handlers.handler_token import HandlerToken from src.handlers.handler_topic import HandlerTopic from src.handlers.handler_health import HandlerHealth @@ -34,14 +38,15 @@ from src.writers.writer_postgres import WriterPostgres from src.utils.conf_path import CONF_DIR, INVALID_CONF_ENV -PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) # Initialize logger -logger = logging.getLogger(__name__) +root_logger = logging.getLogger() +if not root_logger.handlers: + root_logger.addHandler(logging.StreamHandler()) + log_level = os.environ.get("LOG_LEVEL", "INFO") -logger.setLevel(log_level) -if not logger.handlers: - logger.addHandler(logging.StreamHandler()) +root_logger.setLevel(log_level) +logger = logging.getLogger(__name__) logger.debug("Initialized logger with level %s", log_level) # Load main configuration @@ -52,11 +57,6 @@ config = json.load(file) logger.debug("Loaded main configuration") -# Load API definition -with open(os.path.join(PROJECT_ROOT, "api.yaml"), "r", encoding="utf-8") as file: - API = file.read() -logger.debug("Loaded API definition") - # Initialize S3 client with SSL verification try: ssl_verify = config.get(SSL_CA_BUNDLE_KEY, True) @@ -66,21 +66,6 @@ logger.exception("Failed to initialize AWS S3 client") raise RuntimeError("AWS S3 client initialization failed") from exc -# Load access configuration -ACCESS: Dict[str, list[str]] = {} -if config["access_config"].startswith("s3://"): - name_parts = config["access_config"].split("/") - BUCKET_NAME = name_parts[2] - BUCKET_OBJECT_KEY = "/".join(name_parts[3:]) - ACCESS = json.loads(aws_s3.Bucket(BUCKET_NAME).Object(BUCKET_OBJECT_KEY).get()["Body"].read().decode("utf-8")) -else: - with open(config["access_config"], "r", encoding="utf-8") as file: - ACCESS = json.load(file) -logger.debug("Loaded access configuration") - -# Initialize token handler and load token public keys -handler_token = HandlerToken(config).load_public_keys() - # Initialize EventGate writers writers = { "kafka": WriterKafka(config), @@ -88,21 +73,28 @@ "postgres": WriterPostgres(config), } -# Initialize topic handler and load topic schemas -handler_topic = HandlerTopic(CONF_DIR, ACCESS, handler_token, writers).load_topic_schemas() - -# Initialize health handler +# Initialize EventGate handlers +handler_token = HandlerToken(config).with_public_keys_queried() +handler_topic = HandlerTopic(config, aws_s3, handler_token, writers).with_load_access_config().with_load_topic_schemas() handler_health = HandlerHealth(writers) +handler_api = HandlerApi().with_api_definition_loaded() -def get_api() -> Dict[str, Any]: - """Return the OpenAPI specification text.""" - return {"statusCode": 200, "body": API} +# Route to handler function mapping +ROUTE_MAP: Dict[str, Any] = { + "/api": lambda _: handler_api.get_api(), + "/token": lambda _: handler_token.get_token_provider_info(), + "/health": lambda _: handler_health.get_health(), + "/topics": lambda _: handler_topic.get_topics_list(), + "/topics/{topic_name}": handler_topic.handle_request, + "/terminate": lambda _: sys.exit("TERMINATING"), +} def lambda_handler(event: Dict[str, Any], _context: Any = None) -> Dict[str, Any]: """ AWS Lambda entry point. Dispatches based on API Gateway proxy 'resource' and 'httpMethod'. + Args: event: The event data from API Gateway. _context: The mandatory context argument for AWS Lambda invocation (unused). @@ -113,26 +105,11 @@ def lambda_handler(event: Dict[str, Any], _context: Any = None) -> Dict[str, Any """ try: resource = event.get("resource", "").lower() - if resource == "/api": - return get_api() - if resource == "/token": - return handler_token.get_token_provider_info() - if resource == "/health": - return handler_health.get_health() - if resource == "/topics": - return handler_topic.get_topics_list() - if resource == "/topics/{topic_name}": - method = event.get("httpMethod") - if method == "GET": - return handler_topic.get_topic_schema(event["pathParameters"]["topic_name"].lower()) - if method == "POST": - return handler_topic.post_topic_message( - event["pathParameters"]["topic_name"].lower(), - json.loads(event["body"]), - handler_token.extract_token(event.get("headers", {})), - ) - if resource == "/terminate": - sys.exit("TERMINATING") + route_function = ROUTE_MAP.get(resource) + + if route_function: + return route_function(event) + return build_error_response(404, "route", "Resource not found") except (KeyError, json.JSONDecodeError, ValueError, AttributeError, TypeError, RuntimeError) as request_exc: logger.exception("Request processing error: %s", request_exc) diff --git a/src/handlers/handler_api.py b/src/handlers/handler_api.py new file mode 100644 index 0000000..898775d --- /dev/null +++ b/src/handlers/handler_api.py @@ -0,0 +1,73 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides the HandlerApi class for serving the OpenAPI specification. +""" + +import logging +import os +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class HandlerApi: + """ + HandlerApi manages the OpenAPI specification endpoint. + """ + + def __init__(self): + self.api_spec: str = "" + + def with_api_definition_loaded(self) -> "HandlerApi": + """ + Load the OpenAPI specification from api.yaml file. + + Returns: + HandlerApi: The current instance with loaded API definition. + Raises: + RuntimeError: If loading or reading the API specification fails. + """ + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + api_path = os.path.join(project_root, "api.yaml") + + try: + with open(api_path, "r", encoding="utf-8") as file: + self.api_spec = file.read() + + if not self.api_spec: + raise ValueError("API specification file is empty") + + logger.debug("Loaded API definition from %s", api_path) + return self + except (FileNotFoundError, PermissionError, ValueError) as exc: + logger.exception("Failed to load or read API specification from %s", api_path) + raise RuntimeError("API specification initialization failed") from exc + + def get_api(self) -> Dict[str, Any]: + """ + Return the OpenAPI specification. + + Returns: + Dict[str, Any]: API Gateway response with OpenAPI spec. + """ + logger.debug("Handling GET API") + return { + "statusCode": 200, + "headers": {"Content-Type": "application/yaml"}, + "body": self.api_spec, + } diff --git a/src/handlers/handler_token.py b/src/handlers/handler_token.py index 1762728..4af1edb 100644 --- a/src/handlers/handler_token.py +++ b/src/handlers/handler_token.py @@ -70,11 +70,11 @@ def _refresh_keys_if_needed(self) -> None: return try: logger.debug("Token public keys are stale, refreshing now") - self.load_public_keys() + self.with_public_keys_queried() except RuntimeError: logger.warning("Token public key refresh failed, using existing keys") - def load_public_keys(self) -> "HandlerToken": + def with_public_keys_queried(self) -> "HandlerToken": """ Load token public keys from the configured URL. Returns: diff --git a/src/handlers/handler_topic.py b/src/handlers/handler_topic.py index ed14143..1bd26a0 100644 --- a/src/handlers/handler_topic.py +++ b/src/handlers/handler_topic.py @@ -23,16 +23,16 @@ from typing import Dict, Any import jwt +from boto3.resources.base import ServiceResource from jsonschema import validate from jsonschema.exceptions import ValidationError from src.handlers.handler_token import HandlerToken +from src.utils.conf_path import CONF_DIR from src.utils.utils import build_error_response from src.writers.writer import Writer logger = logging.getLogger(__name__) -log_level = os.environ.get("LOG_LEVEL", "INFO") -logger.setLevel(log_level) class HandlerTopic: @@ -42,24 +42,48 @@ class HandlerTopic: def __init__( self, - conf_dir: str, - access_config: Dict[str, list[str]], + config: Dict[str, Any], + aws_s3: ServiceResource, handler_token: HandlerToken, writers: Dict[str, Writer], ): - self.conf_dir = conf_dir - self.access_config = access_config + self.config = config + self.aws_s3 = aws_s3 self.handler_token = handler_token self.writers = writers + self.access_config: Dict[str, list[str]] = {} self.topics: Dict[str, Dict[str, Any]] = {} - def load_topic_schemas(self) -> "HandlerTopic": + def with_load_access_config(self) -> "HandlerTopic": + """ + Load access control configuration from S3 or local file. + Returns: + HandlerTopic: The current instance with loaded access config. + """ + access_path = self.config["access_config"] + logger.debug("Loading access configuration from %s", access_path) + + if access_path.startswith("s3://"): + name_parts = access_path.split("/") + bucket_name = name_parts[2] + bucket_object_key = "/".join(name_parts[3:]) + self.access_config = json.loads( + self.aws_s3.Bucket(bucket_name).Object(bucket_object_key).get()["Body"].read().decode("utf-8") + ) + else: + with open(access_path, "r", encoding="utf-8") as file: + self.access_config = json.load(file) + + logger.debug("Loaded access configuration") + return self + + def with_load_topic_schemas(self) -> "HandlerTopic": """ Load topic schemas from configuration files. Returns: HandlerTopic: The current instance with loaded topic schemas. """ - topic_schemas_dir = os.path.join(self.conf_dir, "topic_schemas") + topic_schemas_dir = os.path.join(CONF_DIR, "topic_schemas") logger.debug("Loading topic schemas from %s", topic_schemas_dir) with open(os.path.join(topic_schemas_dir, "runs.json"), "r", encoding="utf-8") as file: @@ -85,7 +109,29 @@ def get_topics_list(self) -> Dict[str, Any]: "body": json.dumps(list(self.topics)), } - def get_topic_schema(self, topic_name: str) -> Dict[str, Any]: + def handle_request(self, event: Dict[str, Any]) -> Dict[str, Any]: + """ + Handle GET/POST requests for /topics/{topic_name} resource. + + Args: + event: The API Gateway event containing path parameters, method, body, and headers. + Returns: + Dict[str, Any]: API Gateway response. + """ + topic_name = event["pathParameters"]["topic_name"].lower() + method = event.get("httpMethod") + + if method == "GET": + return self._get_topic_schema(topic_name) + if method == "POST": + return self._post_topic_message( + topic_name, + json.loads(event["body"]), + self.handler_token.extract_token(event.get("headers", {})), + ) + return build_error_response(404, "route", "Resource not found") + + def _get_topic_schema(self, topic_name: str) -> Dict[str, Any]: """ Return the JSON schema for a specific topic. Args: @@ -104,7 +150,7 @@ def get_topic_schema(self, topic_name: str) -> Dict[str, Any]: "body": json.dumps(self.topics[topic_name]), } - def post_topic_message(self, topic_name: str, topic_message: Dict[str, Any], token_encoded: str) -> Dict[str, Any]: + def _post_topic_message(self, topic_name: str, topic_message: Dict[str, Any], token_encoded: str) -> Dict[str, Any]: """ Validate auth and schema; dispatch message to all writers. Args: @@ -114,11 +160,16 @@ def post_topic_message(self, topic_name: str, topic_message: Dict[str, Any], tok Returns: Dict[str, Any]: API Gateway response indicating success or failure. Raises: + RuntimeError: If access configuration is not loaded. jwt.PyJWTError: If token decoding fails. ValidationError: If message validation fails. """ logger.debug("Handling POST TopicMessage(%s)", topic_name) + if not self.access_config: + logger.error("Access configuration not loaded") + raise RuntimeError("Access configuration not loaded") + try: token: Dict[str, Any] = self.handler_token.decode_jwt(token_encoded) except jwt.PyJWTError: # type: ignore[attr-defined] diff --git a/tests/handlers/test_handler_api.py b/tests/handlers/test_handler_api.py new file mode 100644 index 0000000..f25d8c9 --- /dev/null +++ b/tests/handlers/test_handler_api.py @@ -0,0 +1,48 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from unittest.mock import patch, mock_open + +from src.handlers.handler_api import HandlerApi + + +def test_load_api_definition_success(): + """Test successful loading of API definition.""" + mock_content = "openapi: 3.0.0\ninfo:\n title: Test API" + with patch("builtins.open", mock_open(read_data=mock_content)): + handler = HandlerApi().with_api_definition_loaded() + assert handler.api_spec == mock_content + + +def test_load_api_definition_file_not_found(): + """Test that RuntimeError is raised when api.yaml doesn't exist.""" + with patch("builtins.open", side_effect=FileNotFoundError("api.yaml not found")): + handler = HandlerApi() + with pytest.raises(RuntimeError, match="API specification initialization failed"): + handler.with_api_definition_loaded() + + +def test_get_api_returns_correct_response(): + """Test get_api returns correct response structure.""" + mock_content = "openapi: 3.0.0" + with patch("builtins.open", mock_open(read_data=mock_content)): + handler = HandlerApi().with_api_definition_loaded() + response = handler.get_api() + + assert response["statusCode"] == 200 + assert response["headers"]["Content-Type"] == "application/yaml" + assert response["body"] == mock_content diff --git a/tests/handlers/test_handler_token.py b/tests/handlers/test_handler_token.py index 487860d..70cf8b1 100644 --- a/tests/handlers/test_handler_token.py +++ b/tests/handlers/test_handler_token.py @@ -106,7 +106,7 @@ def test_refresh_keys_not_needed_when_keys_fresh(token_handler): token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=10) token_handler.public_keys = [Mock(spec=RSAPublicKey)] - with patch.object(token_handler, "load_public_keys") as mock_load: + with patch.object(token_handler, "with_public_keys_queried") as mock_load: token_handler._refresh_keys_if_needed() mock_load.assert_not_called() @@ -116,7 +116,7 @@ def test_refresh_keys_triggered_when_keys_stale(token_handler): token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) token_handler.public_keys = [Mock(spec=RSAPublicKey)] - with patch.object(token_handler, "load_public_keys") as mock_load: + with patch.object(token_handler, "with_public_keys_queried") as mock_load: token_handler._refresh_keys_if_needed() mock_load.assert_called_once() @@ -127,7 +127,7 @@ def test_refresh_keys_handles_load_failure_gracefully(token_handler): token_handler.public_keys = [old_key] token_handler._last_loaded_at = datetime.now(timezone.utc) - timedelta(minutes=29) - with patch.object(token_handler, "load_public_keys", side_effect=RuntimeError("Network error")): + with patch.object(token_handler, "with_public_keys_queried", side_effect=RuntimeError("Network error")): token_handler._refresh_keys_if_needed() assert token_handler.public_keys == [old_key] diff --git a/tests/handlers/test_handler_topic.py b/tests/handlers/test_handler_topic.py index ecfc7c1..cc69386 100644 --- a/tests/handlers/test_handler_topic.py +++ b/tests/handlers/test_handler_topic.py @@ -22,6 +22,52 @@ from src.handlers.handler_topic import HandlerTopic +## load_access_config() +def test_load_access_config_from_local_file(): + """Test loading access config from local file.""" + mock_handler_token = MagicMock() + mock_aws_s3 = MagicMock() + mock_writers = { + "kafka": MagicMock(), + "eventbridge": MagicMock(), + "postgres": MagicMock(), + } + config = {"access_config": "conf/access.json"} + handler = HandlerTopic(config, mock_aws_s3, mock_handler_token, mock_writers) + + access_data = {"public.cps.za.test": ["TestUser"]} + with patch("builtins.open", mock_open(read_data=json.dumps(access_data))): + result = handler.with_load_access_config() + + assert result is handler + assert handler.access_config == access_data + + +def test_load_access_config_from_s3(): + """Test loading access config from S3.""" + mock_handler_token = MagicMock() + mock_aws_s3 = MagicMock() + mock_writers = { + "kafka": MagicMock(), + "eventbridge": MagicMock(), + "postgres": MagicMock(), + } + config = {"access_config": "s3://my-bucket/path/to/access.json"} + handler = HandlerTopic(config, mock_aws_s3, mock_handler_token, mock_writers) + + access_data = {"public.cps.za.test": ["TestUser"]} + mock_body = MagicMock() + mock_body.read.return_value = json.dumps(access_data).encode("utf-8") + mock_aws_s3.Bucket.return_value.Object.return_value.get.return_value = {"Body": mock_body} + + result = handler.with_load_access_config() + + assert result is handler + assert handler.access_config == access_data + mock_aws_s3.Bucket.assert_called_once_with("my-bucket") + mock_aws_s3.Bucket.return_value.Object.assert_called_once_with("path/to/access.json") + + ## load_topic_schemas() def test_load_topic_schemas_success(): mock_handler_token = MagicMock() @@ -30,8 +76,9 @@ def test_load_topic_schemas_success(): "eventbridge": MagicMock(), "postgres": MagicMock(), } - access_config = {"public.cps.za.test": ["TestUser"]} - handler = HandlerTopic("conf", access_config, mock_handler_token, mock_writers) + config = {"access_config": "conf/access.json"} + mock_aws_s3 = MagicMock() + handler = HandlerTopic(config, mock_aws_s3, mock_handler_token, mock_writers) mock_schemas = { "runs.json": {"type": "object", "properties": {"run_id": {"type": "string"}}}, @@ -46,7 +93,7 @@ def mock_open_side_effect(file_path, *_args, **_kwargs): raise FileNotFoundError(file_path) with patch("builtins.open", side_effect=mock_open_side_effect): - result = handler.load_topic_schemas() + result = handler.with_load_topic_schemas() assert result is handler assert len(handler.topics) == 3 diff --git a/tests/test_conf_validation.py b/tests/test_conf_validation.py index ef4ec1c..da001b1 100644 --- a/tests/test_conf_validation.py +++ b/tests/test_conf_validation.py @@ -59,9 +59,9 @@ def test_access_json_structure(): assert all(isinstance(u, str) for u in users), f"All users for topic {topic} must be strings" -@pytest.mark.parametrize("topic_file", glob(os.path.join(CONF_DIR, "topic_*.json"))) +@pytest.mark.parametrize("topic_file", glob(os.path.join(CONF_DIR, "topic_schemas", "*.json"))) def test_topic_json_schemas_basic(topic_file): - assert topic_file, "No topic_*.json files found" + assert topic_file, "No *.json files found in topic_schemas/" schema = load_json(topic_file) assert schema.get("type") == "object", "Schema root 'type' must be 'object'" props = schema.get("properties") diff --git a/tests/test_event_gate_lambda_local_access.py b/tests/test_event_gate_lambda_local_access.py index 1f0023b..ed20932 100644 --- a/tests/test_event_gate_lambda_local_access.py +++ b/tests/test_event_gate_lambda_local_access.py @@ -56,4 +56,4 @@ def Bucket(self, name): # noqa: D401 egl_reloaded = importlib.reload(egl) assert not egl_reloaded.config["access_config"].startswith("s3://") # type: ignore[attr-defined] - assert egl_reloaded.ACCESS["public.cps.za.test"] == ["User"] # type: ignore[attr-defined] + assert egl_reloaded.handler_topic.access_config["public.cps.za.test"] == ["User"] # type: ignore[attr-defined]