diff --git a/CHANGELOG.md b/CHANGELOG.md index 52bc006893..ff6ab762d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -110,6 +110,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#4078](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4171)) - `opentelemetry-instrumentation-aiohttp-server`: fix HTTP error inconsistencies ([#4175](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4175)) +- `opentelemetry-instrumentation-aws-lambda`, `opentelemetry-instrumentation`: fix improper handling of header casing + ([#4216](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/4216)) ### Breaking changes diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py index e666250c63..32e4943316 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/src/opentelemetry/instrumentation/aws_lambda/__init__.py @@ -84,6 +84,7 @@ def custom_event_context_extractor(lambda_event): from opentelemetry.context.context import Context from opentelemetry.instrumentation.aws_lambda.package import _instruments from opentelemetry.instrumentation.aws_lambda.version import __version__ +from opentelemetry.instrumentation.cidict import CIDict from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.instrumentation.utils import unwrap from opentelemetry.metrics import MeterProvider, get_meter_provider @@ -176,7 +177,9 @@ def _default_event_context_extractor(lambda_event: Any) -> Context: ) if not isinstance(headers, dict): headers = {} - return get_global_textmap().extract(headers) + return get_global_textmap().extract( + CIDict(headers), + ) def 
_determine_parent_context( @@ -216,20 +219,21 @@ def _set_api_gateway_v1_proxy_attributes( span.set_attribute(HTTP_METHOD, lambda_event.get("httpMethod")) if lambda_event.get("headers"): - if "User-Agent" in lambda_event["headers"]: + headers = CIDict(lambda_event["headers"]) + if "User-Agent" in headers: span.set_attribute( HTTP_USER_AGENT, - lambda_event["headers"]["User-Agent"], + headers["User-Agent"], ) - if "X-Forwarded-Proto" in lambda_event["headers"]: + if "X-Forwarded-Proto" in headers: span.set_attribute( HTTP_SCHEME, - lambda_event["headers"]["X-Forwarded-Proto"], + headers["X-Forwarded-Proto"], ) - if "Host" in lambda_event["headers"]: + if "Host" in headers: span.set_attribute( NET_HOST_NAME, - lambda_event["headers"]["Host"], + headers["Host"], ) if "resource" in lambda_event: span.set_attribute(HTTP_ROUTE, lambda_event["resource"]) diff --git a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py index 9f10a9f0fa..be5a64a340 100644 --- a/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py +++ b/instrumentation/opentelemetry-instrumentation-aws-lambda/tests/test_aws_lambda_instrumentation_manual.py @@ -325,6 +325,36 @@ def custom_event_context_extractor(lambda_event): expected_baggage=MOCK_W3C_BAGGAGE_VALUE, propagators="tracecontext,baggage", ), + TestCase( + name="case_insensitive_headers_uppercase", + custom_extractor=None, + context={ + "headers": { + TraceContextTextMapPropagator._TRACEPARENT_HEADER_NAME.upper(): MOCK_W3C_TRACE_CONTEXT_SAMPLED, + TraceContextTextMapPropagator._TRACESTATE_HEADER_NAME.upper(): f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", + } + }, + expected_traceid=MOCK_W3C_TRACE_ID, + expected_parentid=MOCK_W3C_PARENT_SPAN_ID, + expected_trace_state_len=3, + 
expected_state_value=MOCK_W3C_TRACE_STATE_VALUE, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, + ), + TestCase( + name="case_insensitive_headers_mixedcase", + custom_extractor=None, + context={ + "headers": { + "TraceParent": MOCK_W3C_TRACE_CONTEXT_SAMPLED, + "tRaCeStAtE": f"{MOCK_W3C_TRACE_STATE_KEY}={MOCK_W3C_TRACE_STATE_VALUE},foo=1,bar=2", + } + }, + expected_traceid=MOCK_W3C_TRACE_ID, + expected_parentid=MOCK_W3C_PARENT_SPAN_ID, + expected_trace_state_len=3, + expected_state_value=MOCK_W3C_TRACE_STATE_VALUE, + xray_traceid=MOCK_XRAY_TRACE_CONTEXT_NOT_SAMPLED, + ), ] for test in tests: with self.subTest(test_name=test.name): @@ -400,6 +430,57 @@ def test_lambda_no_error_with_invalid_flush_timeout(self): test_env_patch.stop() + def test_api_gateway_v1_attributes_case_insensitivity(self): + AwsLambdaInstrumentor().instrument() + + mock_execute_lambda( + { + "httpMethod": "GET", + "headers": { + "user-agent": "lowercase-agent", + "host": "lowercase-host", + "x-forwarded-proto": "http", + }, + "resource": "/test", + "requestContext": { + "version": "1.0", + }, + } + ) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual( + span.attributes.get(HTTP_USER_AGENT), "lowercase-agent" + ) + self.assertEqual(span.attributes.get(NET_HOST_NAME), "lowercase-host") + self.assertEqual(span.attributes.get(HTTP_SCHEME), "http") + + self.memory_exporter.clear() + + mock_execute_lambda( + { + "httpMethod": "GET", + "headers": { + "uSeR-aGeNt": "mixed-agent", + "hOsT": "mixed-host", + "X-fOrWaRdEd-PrOtO": "https", + }, + "resource": "/test", + "requestContext": { + "version": "1.0", + }, + } + ) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertEqual(span.attributes.get(HTTP_USER_AGENT), "mixed-agent") + self.assertEqual(span.attributes.get(NET_HOST_NAME), "mixed-host") + self.assertEqual(span.attributes.get(HTTP_SCHEME), "https") + def 
test_lambda_handles_multiple_consumers(self): test_env_patch = mock.patch.dict( "os.environ", diff --git a/opentelemetry-instrumentation/src/opentelemetry/instrumentation/cidict.py b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/cidict.py new file mode 100644 index 0000000000..f112c41d0e --- /dev/null +++ b/opentelemetry-instrumentation/src/opentelemetry/instrumentation/cidict.py @@ -0,0 +1,93 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from typing import (
    Any,
    Generic,
    Iterable,
    Iterator,
    Mapping,
    MutableMapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

KT = TypeVar("KT")
VT = TypeVar("VT")


class CIDict(MutableMapping[KT, VT], Generic[KT, VT]):
    """Mutable mapping whose string keys are matched case-insensitively.

    String keys are lower-cased for lookup; keys of any other type are used
    unchanged. The spelling used when a key was last assigned is remembered,
    so iteration, ``keys()`` and ``items()`` yield the original keys rather
    than the lower-cased ones.
    """

    def __init__(
        self,
        data: Optional[Union[Mapping[KT, VT], Iterable[Tuple[KT, VT]]]] = None,
    ) -> None:
        """Create the mapping, optionally seeded from ``data``.

        Args:
            data: A mapping or an iterable of ``(key, value)`` pairs. When
                several keys differ only by case, the last one wins.
        """
        # normalized key -> (key as last written, value)
        self._data: dict[KT, Tuple[KT, VT]] = {}
        if data is not None:
            self.update(data)

    @staticmethod
    def _normalize_key(key: KT) -> KT:
        """Return the lookup form of ``key``: lower-cased if it is a ``str``."""
        if isinstance(key, str):
            return key.lower()  # type: ignore
        return key

    def _get_entry(self, key: KT) -> Tuple[KT, VT]:
        """Return the stored ``(original_key, value)`` pair for ``key``.

        Raises:
            KeyError: If no entry matches ``key``. The exception carries the
                key exactly as the caller passed it, mirroring ``dict``.
        """
        try:
            return self._data[self._normalize_key(key)]
        except KeyError:
            # Report the caller's spelling, not the normalized one, and hide
            # the internal lookup from the traceback.
            raise KeyError(key) from None

    def original_key(self, key: KT) -> KT:
        """Return the key spelling that was used when the entry was last set.

        Raises:
            KeyError: If no entry matches ``key``.
        """
        return self._get_entry(key)[0]

    def normalized_items(self) -> Iterable[Tuple[KT, VT]]:
        """Return an iterable of ``(normalized_key, value)`` pairs."""
        return ((key, value[1]) for key, value in self._data.items())

    def __setitem__(self, key: KT, value: VT, /) -> None:
        # Overwriting an entry also replaces the remembered key spelling.
        self._data[self._normalize_key(key)] = (key, value)

    def __delitem__(self, key: KT, /) -> None:
        try:
            del self._data[self._normalize_key(key)]
        except KeyError:
            raise KeyError(key) from None

    def __getitem__(self, key: KT, /) -> VT:
        return self._get_entry(key)[1]

    def __contains__(self, key: object) -> bool:
        # Direct membership test; avoids the raise-and-catch round trip of the
        # inherited Mapping.__contains__ on every miss.
        return self._normalize_key(key) in self._data  # type: ignore

    def __len__(self) -> int:
        return len(self._data)

    def __iter__(self) -> Iterator[KT]:
        # Iterate over the keys as originally written.
        return (key for key, _ in self._data.values())

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({dict(self.items())!r})"

    def __eq__(self, other: Any) -> bool:
        """Compare with another mapping, ignoring the case of string keys.

        Returns ``NotImplemented`` for objects that are not mappings so that
        Python can try the reflected comparison before falling back to
        ``False``.
        """
        if not isinstance(other, Mapping):
            return NotImplemented
        if not isinstance(other, CIDict):
            # Normalize a plain mapping's keys the same way before comparing.
            other = CIDict(other)
        return dict(self.normalized_items()) == dict(other.normalized_items())
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=redefined-outer-name

from typing import Any, Optional

import pytest

from opentelemetry.instrumentation.cidict import CIDict


@pytest.fixture
def simple_cidict() -> CIDict[str, int]:
    """Provide a three-entry CIDict with capitalized keys."""
    seed = {"Alpha": 1, "Beta": 2, "Gamma": 3}
    return CIDict(seed)


@pytest.mark.parametrize(
    "data, expected_len",
    [
        (None, 0),
        ({}, 0),
        ({"a": 1, "b": 2}, 2),
        ({"A": 1, "a": 2}, 1),
    ],
)
def test_init_from_mapping(
    data: Optional[dict[str, int]], expected_len: int
) -> None:
    """Building from a mapping (or None) collapses case-only duplicates."""
    built = CIDict(data)
    assert len(built) == expected_len


@pytest.mark.parametrize(
    "pairs, expected_len",
    [
        ([], 0),
        ([("x", 10), ("y", 20)], 2),
        ([("X", 1), ("x", 2)], 1),
    ],
)
def test_init_from_iterable_of_pairs(
    pairs: list[tuple[str, int]], expected_len: int
) -> None:
    """Building from key/value pairs collapses case-only duplicates."""
    built = CIDict(pairs)
    assert len(built) == expected_len


def test_init_non_string_keys() -> None:
    """Non-string keys are stored and looked up unchanged."""
    numbered: CIDict[int, str] = CIDict({1: "one", 2: "two"})
    assert numbered[1] == "one"
    assert numbered[2] == "two"


@pytest.mark.parametrize("key", ["Alpha", "alpha", "ALPHA", "aLpHa"])
def test_getitem_case_insensitive(
    simple_cidict: CIDict[str, int], key: str
) -> None:
    """Any casing of an existing key returns its value."""
    fetched = simple_cidict[key]
    assert fetched == 1


def test_getitem_missing_raises(simple_cidict: CIDict[str, int]) -> None:
    """Looking up an absent key raises KeyError."""
    with pytest.raises(KeyError):
        simple_cidict.__getitem__("missing")


@pytest.mark.parametrize(
    "set_key, get_key, value",
    [
        ("Hello", "hello", 99),
        ("WORLD", "world", 42),
        ("MiXeD", "mixed", 7),
    ],
)
def test_setitem_case_insensitive(
    set_key: str, get_key: str, value: int
) -> None:
    """A value stored under one casing is readable under another."""
    target: CIDict[str, int] = CIDict()
    target[set_key] = value
    stored = target[get_key]
    assert stored == value


def test_setitem_overwrite_updates_original_key() -> None:
    """Re-assigning with a new casing replaces value and remembered key."""
    target: CIDict[str, int] = CIDict()
    target["Hello"] = 1
    target["HELLO"] = 2
    assert target["hello"] == 2
    remembered = target.original_key("hello")
    assert remembered == "HELLO"


@pytest.mark.parametrize("del_key", ["Alpha", "alpha", "ALPHA", "aLpHa"])
def test_delitem_case_insensitive(
    simple_cidict: CIDict[str, int], del_key: str
) -> None:
    """Deleting by any casing removes the entry."""
    del simple_cidict[del_key]
    assert "alpha" not in simple_cidict
    remaining = len(simple_cidict)
    assert remaining == 2


def test_delitem_missing_raises(simple_cidict: CIDict[str, int]) -> None:
    """Deleting an absent key raises KeyError."""
    with pytest.raises(KeyError):
        del simple_cidict["nonexistent"]


def test_iter_yields_original_keys(simple_cidict: CIDict[str, int]) -> None:
    """Iteration produces keys in their original casing."""
    seen = set(iter(simple_cidict))
    assert seen == {"Alpha", "Beta", "Gamma"}


@pytest.mark.parametrize(
    "lookup_key, expected_original",
    [
        ("Alpha", "Alpha"),
        ("alpha", "Alpha"),
        ("ALPHA", "Alpha"),
    ],
)
def test_original_key(
    simple_cidict: CIDict[str, int], lookup_key: str, expected_original: str
) -> None:
    """original_key returns the stored spelling for any lookup casing."""
    remembered = simple_cidict.original_key(lookup_key)
    assert remembered == expected_original


def test_normalized_items(simple_cidict: CIDict[str, int]) -> None:
    """normalized_items pairs lower-cased keys with their values."""
    pairs = set(simple_cidict.normalized_items())
    assert pairs == {("alpha", 1), ("beta", 2), ("gamma", 3)}


@pytest.mark.parametrize(
    "key, expected",
    [
        ("Alpha", True),
        ("alpha", True),
        ("ALPHA", True),
        ("missing", False),
    ],
)
def test_contains(
    simple_cidict: CIDict[str, int], key: str, expected: bool
) -> None:
    """Membership tests ignore the casing of string keys."""
    present = key in simple_cidict
    assert present is expected


@pytest.mark.parametrize(
    "other, expected",
    [
        (CIDict({"alpha": 1, "beta": 2, "gamma": 3}), True),
        (CIDict({"ALPHA": 1, "BETA": 2, "GAMMA": 3}), True),
        ({"alpha": 1, "beta": 2, "gamma": 3}, True),
        ({"ALPHA": 1, "BETA": 2, "GAMMA": 3}, True),
        (CIDict({"alpha": 99, "beta": 2, "gamma": 3}), False),
        (CIDict({"alpha": 1, "beta": 2}), False),
        ("not a mapping", False),
        (42, False),
    ],
)
def test_eq(
    simple_cidict: CIDict[str, int], other: Any, expected: bool
) -> None:
    """Equality ignores key casing and rejects non-mappings."""
    outcome = simple_cidict == other
    assert outcome is expected


def test_repr(simple_cidict: CIDict[str, int]) -> None:
    """repr shows the class name and the original-cased contents."""
    rendered = repr(simple_cidict)
    assert rendered == "CIDict({'Alpha': 1, 'Beta': 2, 'Gamma': 3})"


@pytest.mark.parametrize(
    "pop_key, expected_value, remaining_len",
    [
        ("Alpha", 1, 2),
        ("BETA", 2, 2),
        ("gamma", 3, 2),
    ],
)
def test_pop(
    simple_cidict: CIDict[str, int],
    pop_key: str,
    expected_value: int,
    remaining_len: int,
) -> None:
    """pop removes the entry matched by any casing and returns its value."""
    popped = simple_cidict.pop(pop_key)
    assert popped == expected_value
    assert len(simple_cidict) == remaining_len


def test_setdefault_existing(simple_cidict: CIDict[str, int]) -> None:
    """setdefault leaves an existing entry untouched."""
    returned = simple_cidict.setdefault("ALPHA", 999)
    assert returned == 1
    assert simple_cidict["alpha"] == 1


def test_setdefault_missing(simple_cidict: CIDict[str, int]) -> None:
    """setdefault inserts and returns the default for a new key."""
    returned = simple_cidict.setdefault("Delta", 4)
    assert returned == 4
    assert simple_cidict["delta"] == 4


def test_keys(simple_cidict: CIDict[str, int]) -> None:
    """keys() exposes the original-cased keys."""
    names = set(simple_cidict.keys())
    assert names == {"Alpha", "Beta", "Gamma"}


def test_values(simple_cidict: CIDict[str, int]) -> None:
    """values() exposes the stored values."""
    numbers = set(simple_cidict.values())
    assert numbers == {1, 2, 3}


def test_items(simple_cidict: CIDict[str, int]) -> None:
    """items() pairs original-cased keys with their values."""
    expected_pairs = {
        ("Alpha", 1),
        ("Beta", 2),
        ("Gamma", 3),
    }
    assert set(simple_cidict.items()) == expected_pairs