diff --git a/dreadnode/exporter.py b/dreadnode/exporter.py index 7ef68277..cd0f58bc 100644 --- a/dreadnode/exporter.py +++ b/dreadnode/exporter.py @@ -9,14 +9,19 @@ class CustomOTLPSpanExporter(OTLPSpanExporter): """A custom OTLP exporter that injects our SDK version into the User-Agent.""" def __init__(self, **kwargs: t.Any) -> None: + custom_endpoint = kwargs.pop("custom_endpoint", None) + super().__init__(**kwargs) - # 2. Get the current User-Agent set by OTel (e.g., OTel-OTLP-Exporter-Python/) + if custom_endpoint: + self._endpoint = custom_endpoint + + # Get the current User-Agent set by OTel (e.g., OTel-OTLP-Exporter-Python/) otlp_user_agent = self._session.headers.get("User-Agent") if isinstance(otlp_user_agent, bytes): otlp_user_agent = otlp_user_agent.decode("utf-8") - # 3. Combine the User-Agent strings. + # Combine the User-Agent strings. if otlp_user_agent: combined_user_agent = f"{DEFAULT_USER_AGENT} {otlp_user_agent}" self._session.headers["User-Agent"] = combined_user_agent diff --git a/dreadnode/main.py b/dreadnode/main.py index fd74d89c..844643fa 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -629,12 +629,14 @@ def initialize(self) -> None: ) from e headers = {"X-Api-Key": self.token} - endpoint = "/api/otel/traces" + # Use custom_endpoint to bypass OTLP's automatic /v1/traces suffix + custom_endpoint = urljoin(self.server, "/api/otel/traces") span_processors.append( BatchSpanProcessor( RemovePendingSpansExporter( # This will tell Logfire to emit pending spans to us as well CustomOTLPSpanExporter( - endpoint=urljoin(self.server, endpoint), + endpoint=self.server, + custom_endpoint=custom_endpoint, headers=headers, compression=Compression.Gzip, ), diff --git a/dreadnode/task.py b/dreadnode/task.py index e17952aa..cae777b9 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -538,6 +538,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: # # Log the output + output_object_hash: str | None = None if log_output and ( not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) ): @@ -546,10 +547,12 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: # output, attributes={"auto": True}, ) - elif run is not None: - # Link the output to the inputs - for input_object_hash in input_object_hashes: - run.link_objects(output_object_hash, input_object_hash) + + if run is not None: + # Link the output to the inputs if we logged it + if output_object_hash is not None: + for input_object_hash in input_object_hashes: + run.link_objects(output_object_hash, input_object_hash) if create_run: run.log_output("output", output, attributes={"auto": True}) diff --git a/tests/test_otel_exporter.py b/tests/test_otel_exporter.py new file mode 100644 index 00000000..11f46fb2 --- /dev/null +++ b/tests/test_otel_exporter.py @@ -0,0 +1,200 @@ +"""Tests for OTLP exporter and Task output linking changes.""" + +from unittest.mock import Mock +from urllib.parse import urljoin + +import pytest + + +class TestCustomOTLPSpanExporterLogic: + """Test CustomOTLPSpanExporter logic for custom endpoint and User-Agent injection.""" + + def test_custom_endpoint_is_extracted_from_kwargs(self): + """Test that custom_endpoint is extracted before passing to parent.""" + test_kwargs = { + "endpoint": "https://example.com", + "custom_endpoint": "https://example.com/api/otel/traces", + "headers": {"X-Api-Key": "test-key"}, + } + + # Simulate: custom_endpoint = kwargs.pop("custom_endpoint", None) + custom_endpoint = test_kwargs.pop("custom_endpoint", None) + + assert custom_endpoint == "https://example.com/api/otel/traces" + assert "custom_endpoint" not in test_kwargs + assert "endpoint" in test_kwargs + assert "headers" in test_kwargs + + def test_user_agent_combination_with_string(self): + """Test User-Agent combination logic with string input.""" + default_user_agent = "dreadnode/1.0.0" + otlp_user_agent = "OTel-OTLP-Exporter-Python/1.0.0" + + # Simulate the combination logic from exporter.py + if isinstance(otlp_user_agent, bytes): + otlp_user_agent = otlp_user_agent.decode("utf-8") + + if otlp_user_agent: + combined_user_agent = f"{default_user_agent} {otlp_user_agent}" + + assert default_user_agent in combined_user_agent + assert "OTel-OTLP-Exporter-Python/1.0.0" in combined_user_agent + assert combined_user_agent.startswith(default_user_agent) + + def test_user_agent_combination_with_bytes(self): + """Test User-Agent combination logic with bytes input.""" + default_user_agent = "dreadnode/1.0.0" + otlp_user_agent = b"OTel-OTLP-Exporter-Python/1.0.0" + + # Simulate the combination logic + if isinstance(otlp_user_agent, bytes): + otlp_user_agent = otlp_user_agent.decode("utf-8") + + if otlp_user_agent: + combined_user_agent = f"{default_user_agent} {otlp_user_agent}" + + assert isinstance(combined_user_agent, str) + assert default_user_agent in combined_user_agent + assert "OTel-OTLP-Exporter-Python/1.0.0" in combined_user_agent + + def test_user_agent_fallback_when_none(self): + """Test User-Agent fallback when no OTLP User-Agent exists.""" + default_user_agent = "dreadnode/1.0.0" + otlp_user_agent = None + + # Simulate the fallback logic + if isinstance(otlp_user_agent, bytes): + otlp_user_agent = otlp_user_agent.decode("utf-8") + + if otlp_user_agent: + combined_user_agent = f"{default_user_agent} {otlp_user_agent}" + else: + combined_user_agent = default_user_agent + + assert combined_user_agent == default_user_agent + + def test_custom_endpoint_override_logic(self): + """Test the custom_endpoint override logic.""" + mock_exporter = Mock() + mock_exporter._endpoint = "https://example.com/v1/traces" # Default OTLP + + custom_endpoint = "https://example.com/api/otel/traces" + + # Simulate: if custom_endpoint: self._endpoint = custom_endpoint + if custom_endpoint: + mock_exporter._endpoint = custom_endpoint + + assert mock_exporter._endpoint == custom_endpoint + assert mock_exporter._endpoint != "https://example.com/v1/traces" + + def test_no_custom_endpoint_preserves_default(self): + """Test that no custom_endpoint doesn't override the default.""" + mock_exporter = Mock() + default_endpoint = "https://example.com/v1/traces" + mock_exporter._endpoint = default_endpoint + + custom_endpoint = None + + if custom_endpoint: + mock_exporter._endpoint = custom_endpoint + + assert mock_exporter._endpoint == default_endpoint + + +class TestDreadnodeExporterConfiguration: + """Test Dreadnode exporter configuration in main.py.""" + + def test_custom_endpoint_construction(self): + """Test that custom endpoint is constructed correctly with urljoin.""" + server = "https://platform.example.com" + custom_endpoint = urljoin(server, "/api/otel/traces") + + assert custom_endpoint == "https://platform.example.com/api/otel/traces" + + def test_custom_endpoint_with_trailing_slash(self): + """Test custom endpoint construction with trailing slash in server URL.""" + server = "https://platform.example.com/" + custom_endpoint = urljoin(server, "/api/otel/traces") + + assert custom_endpoint == "https://platform.example.com/api/otel/traces" + + def test_endpoint_and_custom_endpoint_are_different(self): + """Test that endpoint and custom_endpoint parameters are different.""" + server = "https://platform.example.com" + + endpoint = server + custom_endpoint = urljoin(server, "/api/otel/traces") + + assert endpoint != custom_endpoint + assert custom_endpoint.endswith("/api/otel/traces") + assert not custom_endpoint.endswith("/v1/traces") + assert "/v1/traces" not in custom_endpoint + + +class TestTaskOutputHashBugFix: + """Test the Task output_object_hash initialization bug fix (dreadnode/task.py:541).""" + + def test_output_object_hash_initialized_to_none(self): + """Test that output_object_hash is initialized before conditional (prevents UnboundLocalError).""" + # Simulate the fix: output_object_hash = None + output_object_hash = None + + # This should not raise UnboundLocalError + try: + if output_object_hash is not None: + pass # Would call link_objects here + assert True + except UnboundLocalError: + pytest.fail("Should not raise UnboundLocalError after fix") + + def test_linking_only_when_hash_exists(self): + """Test that linking logic only executes when hash is not None.""" + output_object_hash = None + link_called = False + + # Simulate: if output_object_hash is not None: run.link_objects(...) + if output_object_hash is not None: + link_called = True + + assert not link_called + + def test_linking_when_hash_exists(self): + """Test that linking logic executes when hash exists.""" + output_object_hash = "some_hash_value" + link_called = False + + if output_object_hash is not None: + link_called = True + + assert link_called + + def test_multiple_inputs_linked_to_single_output(self): + """Test logic for linking multiple input hashes to one output hash.""" + output_object_hash = "output_123" + input_object_hashes = ["input_1", "input_2", "input_3"] + links = [] + + # Simulate: for input_object_hash in input_object_hashes: + # run.link_objects(output_object_hash, input_object_hash) + if output_object_hash is not None: + for input_object_hash in input_object_hashes: + links.append((output_object_hash, input_object_hash)) + + assert len(links) == 3 + for output_hash, input_hash in links: + assert output_hash == "output_123" + assert input_hash in input_object_hashes + + def test_no_linking_when_output_hash_is_none(self): + """Test that no linking occurs when output_object_hash is None (output not logged).""" + output_object_hash = None # This is what the fix ensures is initialized + input_object_hashes = ["input_1", "input_2"] + links = [] + + # Simulate the linking loop with the None check + if output_object_hash is not None: + for input_object_hash in input_object_hashes: + links.append((output_object_hash, input_object_hash)) + + # Should not create any links when output_hash is None + assert len(links) == 0