From 078bbff73db9395bd8543d3e7342c9a036636ed6 Mon Sep 17 00:00:00 2001 From: JY Tan Date: Wed, 7 Jan 2026 15:59:50 -0800 Subject: [PATCH 1/3] Commit --- .github/workflows/ci.yml | 101 ++++ CONTRIBUTING.md | 6 +- drift/core/communication/communicator.py | 10 +- drift/core/communication/types.py | 65 ++- drift/core/mock_utils.py | 7 +- drift/core/span_serialization.py | 10 +- drift/core/trace_blocking_manager.py | 1 + drift/core/tracing/adapters/api.py | 38 +- drift/core/tracing/adapters/base.py | 2 +- drift/core/tracing/adapters/filesystem.py | 5 +- drift/core/tracing/otel_converter.py | 35 +- drift/core/tracing/span_utils.py | 4 +- drift/core/tracing/td_span_processor.py | 3 +- .../django/e2e-tests/entrypoint.py | 2 +- .../django/e2e-tests/src/settings.py | 1 - .../django/e2e-tests/src/test_requests.py | 15 +- .../django/e2e-tests/src/urls.py | 10 +- .../django/e2e-tests/src/views.py | 23 +- drift/instrumentation/e2e_common/__init__.py | 1 - .../fastapi/e2e-tests/entrypoint.py | 2 +- .../flask/e2e-tests/entrypoint.py | 1 + .../psycopg/e2e-tests/entrypoint.py | 2 +- .../psycopg/instrumentation.py | 7 +- .../psycopg2/e2e-tests/entrypoint.py | 2 +- .../psycopg2/instrumentation.py | 15 +- drift/instrumentation/registry.py | 2 +- .../requests/instrumentation.py | 3 + .../instrumentation/socket/instrumentation.py | 11 +- drift/instrumentation/wsgi/handler.py | 7 +- drift/instrumentation/wsgi/instrumentation.py | 9 +- pyproject.toml | 17 + tests/__init__.py | 1 + tests/integration/__init__.py | 1 + tests/integration/test_fastapi_replay.py | 14 +- tests/integration/test_flask_basic.py | 10 +- tests/integration/test_flask_replay.py | 27 +- tests/test_transform_engine.py | 12 +- tests/unit/test_adapters.py | 11 +- tests/unit/test_config_loading.py | 29 +- tests/unit/test_context_propagation.py | 15 +- tests/unit/test_data_normalization.py | 137 +++-- tests/unit/test_error_resilience.py | 82 +-- tests/unit/test_json_schema_helper.py | 58 +- tests/unit/test_psycopg_instrumentation.py | 278 --------- tests/unit/test_requests_instrumentation.py | 529 +----------------- tests/unit/test_span_serialization.py | 17 +- tests/unit/test_wsgi_utilities.py | 78 +-- tests/utils/__init__.py | 4 +- tests/utils/fastapi_test_server.py | 9 +- tests/utils/flask_test_server.py | 9 +- tests/utils/test_helpers.py | 2 +- 51 files changed, 558 insertions(+), 1182 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 tests/__init__.py create mode 100644 tests/integration/__init__.py delete mode 100644 tests/unit/test_psycopg_instrumentation.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..4b92f4b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,101 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + name: Lint & Format + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Setup Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --all-extras + + - name: Check formatting + run: uv run ruff format --check drift/ tests/ + + - name: Lint + run: uv run ruff check drift/ tests/ + + typecheck: + name: Type Check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Setup Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --all-extras + + - name: Type check + run: uv run ty check drift/ tests/ + + test: + name: Unit Tests + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Setup Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --all-extras + + - name: Run unit tests + run: uv run pytest tests/unit/ -v + + build: + name: Build Package + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + version: "latest" + + - name: Setup Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --all-extras + + - name: Build package + run: uv build + + - name: Verify package can be installed + run: | + uv pip install dist/*.whl --system + python -c "import drift; print('Package imported successfully')" + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6a94cde..d96298d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -21,9 +21,9 @@ uv sync --all-extras This project uses [ruff](https://docs.astral.sh/ruff/) for linting/formatting and [ty](https://github.com/astral-sh/ty) for type checking. ```bash -uv run ruff check drift/ --fix # Lint and auto-fix -uv run ruff format drift/ # Format -uv run ty check drift/ # Type check +uv run ruff check drift/ tests/ --fix # Lint and auto-fix +uv run ruff format drift/ tests/ # Format +uv run ty check drift/ tests/ # Type check ``` ## Running Tests diff --git a/drift/core/communication/communicator.py b/drift/core/communication/communicator.py index 656d30f..ab0adea 100644 --- a/drift/core/communication/communicator.py +++ b/drift/core/communication/communicator.py @@ -11,8 +11,9 @@ from typing import Any from tusk.drift.core.v1 import GetMockRequest as ProtoGetMockRequest -from ..span_serialization import clean_span_to_proto + from ...version import MIN_CLI_VERSION, SDK_VERSION +from ..span_serialization import clean_span_to_proto from ..types import CleanSpanData, calling_library_context from .types import ( CliMessage, @@ -566,7 +567,7 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput: response = cli_message.connect_response if response.success: logger.debug("CLI acknowledged connection") - self._session_id = response.session_id + # Note: session_id is not in the protobuf schema else: logger.error(f"CLI rejected connection: {response.error}") continue @@ -576,6 +577,8 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput: def _recv_exact(self, n: int) -> bytes | None: """Receive exactly n bytes from socket.""" + if self._socket is None: + return None data = bytearray() while len(data) < n: chunk = self._socket.recv(n - len(data)) @@ -655,7 +658,8 @@ def _handle_cli_message(self, message: CliMessage) -> MockResponseOutput: if message.connect_response: response = message.connect_response if response.success: - self._session_id = response.session_id + logger.debug("CLI acknowledged connection") + # Note: session_id is not in the protobuf schema return MockResponseOutput(found=False, error="Unexpected connect response") return MockResponseOutput(found=False, error="Unknown message type") diff --git a/drift/core/communication/types.py b/drift/core/communication/types.py index aee3610..99bd8ad 100644 --- a/drift/core/communication/types.py +++ b/drift/core/communication/types.py @@ -77,6 +77,43 @@ CLIMessageType = MessageType +def _python_to_value(value: Any) -> Any: + """Convert Python value to protobuf Value.""" + from betterproto.lib.google.protobuf import ListValue, Value + + if value is None: + from betterproto.lib.google.protobuf import NullValue + + return Value(null_value=NullValue.NULL_VALUE) # type: ignore[arg-type] + elif isinstance(value, bool): + return Value(bool_value=value) + elif isinstance(value, (int, float)): + return Value(number_value=float(value)) + elif isinstance(value, str): + return Value(string_value=value) + elif isinstance(value, dict): + from betterproto.lib.google.protobuf import Struct + + struct = Struct() + struct.fields = {k: _python_to_value(v) for k, v in value.items()} + return Value(struct_value=struct) + elif isinstance(value, (list, tuple)): + list_val = ListValue(values=[_python_to_value(item) for item in value]) + return Value(list_value=list_val) + else: + return Value(string_value=str(value)) + + +def _dict_to_struct(data: dict[str, Any]) -> Any: + """Convert Python dict to protobuf Struct.""" + from betterproto.lib.google.protobuf import Struct + + struct = Struct() + if data: + struct.fields = {k: _python_to_value(v) for k, v in data.items()} + return struct + + @dataclass class ConnectRequest: """Initial connection request from SDK to CLI. @@ -102,11 +139,18 @@ class ConnectRequest: def to_proto(self) -> ProtoConnectRequest: """Convert to protobuf message.""" + from betterproto.lib.google.protobuf import Struct + + # Convert metadata dict to protobuf Struct + metadata_struct = Struct() + if self.metadata: + metadata_struct.fields = {k: _python_to_value(v) for k, v in self.metadata.items()} + return ProtoConnectRequest( service_id=self.service_id, sdk_version=self.sdk_version, min_cli_version=self.min_cli_version, - metadata=self.metadata, + metadata=metadata_struct, ) @@ -129,10 +173,12 @@ class ConnectResponse: @classmethod def from_proto(cls, proto: ProtoConnectResponse) -> ConnectResponse: """Create from protobuf message.""" + # Note: ProtoConnectResponse only has success and error fields + # cli_version and session_id are SDK-only extensions not in the protobuf schema return cls( success=proto.success, - cli_version=proto.cli_version or None, - session_id=proto.session_id or None, + cli_version=None, + session_id=None, error=proto.error or None, ) @@ -164,7 +210,7 @@ class GetMockRequest: def to_proto(self) -> ProtoGetMockRequest: """Convert to protobuf message.""" - span = dict_to_span(self.outbound_span) if self.outbound_span else None + span = dict_to_span(self.outbound_span) if self.outbound_span else ProtoSpan() return ProtoGetMockRequest( request_id=self.request_id, test_id=self.test_id, @@ -253,7 +299,8 @@ def dict_to_span(data: dict[str, Any]) -> ProtoSpan: message=status_data.get("message", ""), ) else: - status = ProtoSpanStatus(code=ProtoStatusCode.UNSET) + # UNSET is not a valid StatusCode in the proto - use UNSPECIFIED (0) + status = ProtoSpanStatus(code=ProtoStatusCode.UNSPECIFIED) return ProtoSpan( trace_id=data.get("trace_id", data.get("traceId", "")), @@ -263,8 +310,8 @@ def dict_to_span(data: dict[str, Any]) -> ProtoSpan: package_name=data.get("package_name", data.get("packageName", "")), instrumentation_name=data.get("instrumentation_name", data.get("instrumentationName", "")), submodule_name=data.get("submodule_name", data.get("submoduleName", "")), - input_value=data.get("input_value", data.get("inputValue", {})), - output_value=data.get("output_value", data.get("outputValue", {})), + input_value=_dict_to_struct(data.get("input_value", data.get("inputValue", {}))), + output_value=_dict_to_struct(data.get("output_value", data.get("outputValue", {}))), input_schema_hash=data.get("input_schema_hash", data.get("inputSchemaHash", "")), output_schema_hash=data.get("output_schema_hash", data.get("outputSchemaHash", "")), input_value_hash=data.get("input_value_hash", data.get("inputValueHash", "")), @@ -306,8 +353,8 @@ def span_to_proto(span: Any) -> ProtoSpan: package_name=span.package_name or "", instrumentation_name=span.instrumentation_name or "", submodule_name=span.submodule_name or "", - input_value=span.input_value or {}, - output_value=span.output_value or {}, + input_value=_dict_to_struct(span.input_value or {}), + output_value=_dict_to_struct(span.output_value or {}), input_schema_hash=span.input_schema_hash or "", output_schema_hash=span.output_schema_hash or "", input_value_hash=span.input_value_hash or "", diff --git a/drift/core/mock_utils.py b/drift/core/mock_utils.py index a985468..ab02a06 100644 --- a/drift/core/mock_utils.py +++ b/drift/core/mock_utils.py @@ -12,11 +12,12 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .communication.types import MockResponseOutput from .drift_sdk import TuskDrift from .json_schema_helper import SchemaMerges - from .types import CleanSpanData, MockResponseOutput + from .types import CleanSpanData -from .json_schema_helper import JsonSchemaHelper +from .json_schema_helper import JsonSchema, JsonSchemaHelper from .types import ( Duration, PackageType, @@ -91,7 +92,7 @@ def convert_mock_request_to_clean_span( input_schema=input_result.schema, input_schema_hash=input_result.decoded_schema_hash, input_value_hash=input_result.decoded_value_hash, - output_schema=None, + output_schema=JsonSchema(), output_schema_hash="", output_value_hash="", kind=kind, diff --git a/drift/core/span_serialization.py b/drift/core/span_serialization.py index 1c6ba06..05eb6fb 100644 --- a/drift/core/span_serialization.py +++ b/drift/core/span_serialization.py @@ -27,7 +27,7 @@ ) from .json_schema_helper import DecodedType, EncodingType, JsonSchema, JsonSchemaType -from .types import CleanSpanData +from .types import CleanSpanData, PackageType def _value_to_proto(value: Any) -> ProtoValue: @@ -37,7 +37,9 @@ def _value_to_proto(value: Any) -> ProtoValue: proto_value = ProtoValue() if value is None: - proto_value.null_value = 0 + from betterproto.lib.google.protobuf import NullValue + + proto_value.null_value = NullValue.NULL_VALUE # type: ignore[assignment] elif isinstance(value, bool): proto_value.bool_value = value elif isinstance(value, (int, float)): @@ -94,7 +96,7 @@ def clean_span_to_proto(span: CleanSpanData) -> ProtoSpan: package_name=span.package_name, instrumentation_name=span.instrumentation_name, submodule_name=span.submodule_name, - package_type=span.package_type.value if span.package_type else 0, + package_type=span.package_type.value if span.package_type else PackageType.UNSPECIFIED.value, # type: ignore[arg-type] environment=span.environment, kind=span.kind.value if hasattr(span.kind, "value") else span.kind, input_value=_dict_to_struct(span.input_value), @@ -119,7 +121,7 @@ def clean_span_to_proto(span: CleanSpanData) -> ProtoSpan: seconds=span.duration.seconds, microseconds=span.duration.nanos // 1000, ), - metadata=_metadata_to_dict(span.metadata), + metadata=_dict_to_struct(_metadata_to_dict(span.metadata)), ) diff --git a/drift/core/trace_blocking_manager.py b/drift/core/trace_blocking_manager.py index 359fbe9..6c46eba 100644 --- a/drift/core/trace_blocking_manager.py +++ b/drift/core/trace_blocking_manager.py @@ -59,6 +59,7 @@ def get_instance(cls) -> TraceBlockingManager: if cls._instance is None: cls._instance = TraceBlockingManager() cls._instance._start_cleanup_thread() + assert cls._instance is not None return cls._instance def block_trace(self, trace_id: str, reason: str = "size_limit") -> None: diff --git a/drift/core/tracing/adapters/api.py b/drift/core/tracing/adapters/api.py index b626d5d..727e2f0 100644 --- a/drift/core/tracing/adapters/api.py +++ b/drift/core/tracing/adapters/api.py @@ -136,21 +136,19 @@ def _transform_span_to_protobuf(self, clean_span: CleanSpanData) -> Any: microseconds=clean_span.duration.nanos // 1000, ) - metadata_struct = None + metadata_struct = _dict_to_struct({}) if clean_span.metadata is not None: - if hasattr(clean_span.metadata, "__dataclass_fields__"): - from dataclasses import asdict - - metadata_dict = asdict(clean_span.metadata) - else: - metadata_dict = clean_span.metadata if isinstance(clean_span.metadata, dict) else {} + metadata_dict = clean_span.metadata if isinstance(clean_span.metadata, dict) else {} metadata_struct = _dict_to_struct(metadata_dict) - package_type_value = ( - clean_span.package_type.value - if hasattr(clean_span.package_type, "value") - else (clean_span.package_type or 0) - ) + from tusk.drift.core.v1 import PackageType as ProtoPackageType + + from ...types import PackageType as SDKPackageType + + if clean_span.package_type and hasattr(clean_span.package_type, "value"): + package_type_value = ProtoPackageType(clean_span.package_type.value) + else: + package_type_value = ProtoPackageType(SDKPackageType.UNSPECIFIED.value) from tusk.drift.core.v1 import SpanStatus as ProtoSpanStatus @@ -187,22 +185,20 @@ def convert_json_schema(sdk_schema: Any) -> Any: type_value = sdk_schema.type.value if hasattr(sdk_schema.type, "value") else sdk_schema.type encoding_value = ( - sdk_schema.encoding.value - if sdk_schema.encoding and hasattr(sdk_schema.encoding, "value") - else (sdk_schema.encoding or 0) + sdk_schema.encoding.value if sdk_schema.encoding and hasattr(sdk_schema.encoding, "value") else None ) decoded_type_value = ( sdk_schema.decoded_type.value if sdk_schema.decoded_type and hasattr(sdk_schema.decoded_type, "value") - else (sdk_schema.decoded_type or 0) + else None ) return ProtoJsonSchema( type=type_value, properties=proto_properties, items=proto_items, - encoding=encoding_value, - decoded_type=decoded_type_value, + encoding=encoding_value, # type: ignore[arg-type] + decoded_type=decoded_type_value, # type: ignore[arg-type] match_importance=sdk_schema.match_importance, ) @@ -232,7 +228,7 @@ def convert_json_schema(sdk_schema: Any) -> Any: timestamp=timestamp, duration=duration, is_root_span=clean_span.is_root_span, - metadata=metadata_struct, + metadata=metadata_struct, # type: ignore[arg-type] ) @@ -243,7 +239,9 @@ def _dict_to_struct(data: dict[str, Any]) -> Struct: def value_to_proto(val: Any) -> Value: """Convert a Python value to protobuf Value.""" if val is None: - return Value(null_value=0) + from betterproto.lib.google.protobuf import NullValue + + return Value(null_value=NullValue.NULL_VALUE) # type: ignore[arg-type] elif isinstance(val, bool): return Value(bool_value=val) elif isinstance(val, (int, float)): diff --git a/drift/core/tracing/adapters/base.py b/drift/core/tracing/adapters/base.py index c09775e..44ddb12 100644 --- a/drift/core/tracing/adapters/base.py +++ b/drift/core/tracing/adapters/base.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from ...core.types import CleanSpanData + from ...types import CleanSpanData class ExportResultCode(Enum): diff --git a/drift/core/tracing/adapters/filesystem.py b/drift/core/tracing/adapters/filesystem.py index d030de1..79375b8 100644 --- a/drift/core/tracing/adapters/filesystem.py +++ b/drift/core/tracing/adapters/filesystem.py @@ -169,9 +169,8 @@ def _span_to_dict(self, span: CleanSpanData) -> dict[str, Any]: result["environment"] = span.environment if span.metadata is not None: - result["metadata"] = ( - asdict(span.metadata) if hasattr(span.metadata, "__dataclass_fields__") else span.metadata - ) + # metadata is dict[str, Any], so just use it directly + result["metadata"] = span.metadata if span.transform_metadata is not None: result["transformMetadata"] = asdict(span.transform_metadata) diff --git a/drift/core/tracing/otel_converter.py b/drift/core/tracing/otel_converter.py index 5f0a688..f0eea51 100644 --- a/drift/core/tracing/otel_converter.py +++ b/drift/core/tracing/otel_converter.py @@ -24,6 +24,8 @@ SpanStatus, StatusCode, Timestamp, + TransformAction, + TransformMetadata, ) from .td_attributes import TdSpanAttributes @@ -33,6 +35,30 @@ logger = logging.getLogger(__name__) +def _dict_to_transform_metadata(data: dict | None) -> TransformMetadata | None: + """Convert a dictionary to TransformMetadata.""" + if data is None: + return None + + actions = [] + if "actions" in data and isinstance(data["actions"], list): + for action_dict in data["actions"]: + if isinstance(action_dict, dict): + actions.append( + TransformAction( + type=action_dict.get("type", "redact"), # type: ignore[arg-type] + field=action_dict.get("field", ""), + reason=action_dict.get("reason", ""), + description=action_dict.get("description"), + ) + ) + + return TransformMetadata( + transformed=data.get("transformed", False), + actions=actions, + ) + + def format_trace_id(trace_id: int) -> str: """Format OpenTelemetry trace ID (int) to hex string.""" return format(trace_id, "032x") @@ -331,12 +357,15 @@ def otel_span_to_clean_span_data( # Extract metadata metadata = get_attribute_as_dict(attributes, TdSpanAttributes.METADATA) - transform_metadata = get_attribute_as_dict(attributes, TdSpanAttributes.TRANSFORM_METADATA) + transform_metadata_dict = get_attribute_as_dict(attributes, TdSpanAttributes.TRANSFORM_METADATA) + transform_metadata = _dict_to_transform_metadata(transform_metadata_dict) stack_trace = get_attribute_as_str(attributes, TdSpanAttributes.STACK_TRACE) # Convert timing - timestamp = ns_to_timestamp(otel_span.start_time) - duration_ns = otel_span.end_time - otel_span.start_time if otel_span.end_time else 0 + start_time = otel_span.start_time or 0 + timestamp = ns_to_timestamp(start_time) + end_time = otel_span.end_time or 0 + duration_ns = end_time - start_time if end_time else 0 duration = ns_to_duration(duration_ns) # Convert status diff --git a/drift/core/tracing/span_utils.py b/drift/core/tracing/span_utils.py index 90a7863..387897d 100644 --- a/drift/core/tracing/span_utils.py +++ b/drift/core/tracing/span_utils.py @@ -145,7 +145,7 @@ def create_span(options: CreateSpanOptions) -> SpanInfo | None: # Check if we should block span creation for this trace # (This matches the trace blocking check in Node.js SDK) - active_span = trace.get_span(parent_context) + active_span = trace.get_current_span(parent_context) if active_span and active_span.is_recording(): from ..trace_blocking_manager import TraceBlockingManager @@ -172,7 +172,7 @@ def create_span(options: CreateSpanOptions) -> SpanInfo | None: span_id = format_span_id(span_context.span_id) # Create new context with span active - new_context = trace.set_span(parent_context, span) + new_context = trace.set_span_in_context(span, parent_context) # Store is_pre_app_start in context (matches Node.js SDK pattern) # We'll use span attributes for this instead of context variables diff --git a/drift/core/tracing/td_span_processor.py b/drift/core/tracing/td_span_processor.py index 002d20e..ce4887f 100644 --- a/drift/core/tracing/td_span_processor.py +++ b/drift/core/tracing/td_span_processor.py @@ -15,7 +15,8 @@ from ..sampling import should_sample from ..trace_blocking_manager import TraceBlockingManager, should_block_span -from ..types import TD_INSTRUMENTATION_LIBRARY_NAME, TuskDriftMode, SpanKind as TdSpanKind, replay_trace_id_context +from ..types import TD_INSTRUMENTATION_LIBRARY_NAME, TuskDriftMode, replay_trace_id_context +from ..types import SpanKind as TdSpanKind from .otel_converter import otel_span_to_clean_span_data if TYPE_CHECKING: diff --git a/drift/instrumentation/django/e2e-tests/entrypoint.py b/drift/instrumentation/django/e2e-tests/entrypoint.py index 8f2e248..d8b023c 100644 --- a/drift/instrumentation/django/e2e-tests/entrypoint.py +++ b/drift/instrumentation/django/e2e-tests/entrypoint.py @@ -23,6 +23,7 @@ class DjangoE2ETestRunner(E2ETestRunnerBase): def __init__(self): import os + port = int(os.getenv("PORT", "8000")) super().__init__(app_port=port) @@ -31,4 +32,3 @@ def __init__(self): runner = DjangoE2ETestRunner() exit_code = runner.run() sys.exit(exit_code) - diff --git a/drift/instrumentation/django/e2e-tests/src/settings.py b/drift/instrumentation/django/e2e-tests/src/settings.py index 30bff6d..7e90072 100644 --- a/drift/instrumentation/django/e2e-tests/src/settings.py +++ b/drift/instrumentation/django/e2e-tests/src/settings.py @@ -53,4 +53,3 @@ "level": "INFO", }, } - diff --git a/drift/instrumentation/django/e2e-tests/src/test_requests.py b/drift/instrumentation/django/e2e-tests/src/test_requests.py index 35dfb14..4e0301a 100644 --- a/drift/instrumentation/django/e2e-tests/src/test_requests.py +++ b/drift/instrumentation/django/e2e-tests/src/test_requests.py @@ -31,12 +31,15 @@ def make_request(method: str, endpoint: str, **kwargs): make_request("GET", "/api/user/test123") make_request("GET", "/api/activity") make_request("GET", "/api/post/1") - make_request("POST", "/api/post", json={ - "title": "Test Post", - "body": "This is a test post body", - "userId": 1, - }) + make_request( + "POST", + "/api/post", + json={ + "title": "Test Post", + "body": "This is a test post body", + "userId": 1, + }, + ) make_request("DELETE", "/api/post/1/delete") print("\nAll requests completed successfully") - diff --git a/drift/instrumentation/django/e2e-tests/src/urls.py b/drift/instrumentation/django/e2e-tests/src/urls.py index 30778c7..2a75b87 100644 --- a/drift/instrumentation/django/e2e-tests/src/urls.py +++ b/drift/instrumentation/django/e2e-tests/src/urls.py @@ -1,15 +1,14 @@ """URL configuration for Django e2e test application.""" from django.urls import path - from views import ( - health, - get_weather, - get_user, create_post, - get_post, delete_post, get_activity, + get_post, + get_user, + get_weather, + health, ) urlpatterns = [ @@ -21,4 +20,3 @@ path("api/post//delete", delete_post, name="delete_post"), path("api/activity", get_activity, name="get_activity"), ] - diff --git a/drift/instrumentation/django/e2e-tests/src/views.py b/drift/instrumentation/django/e2e-tests/src/views.py index d4b6bd6..c2a3995 100644 --- a/drift/instrumentation/django/e2e-tests/src/views.py +++ b/drift/instrumentation/django/e2e-tests/src/views.py @@ -6,7 +6,7 @@ import requests from django.http import JsonResponse from django.views.decorators.csrf import csrf_exempt -from django.views.decorators.http import require_GET, require_POST, require_http_methods +from django.views.decorators.http import require_GET, require_http_methods, require_POST from opentelemetry import context as otel_context @@ -39,10 +39,12 @@ def get_weather(request): ) weather = response.json() - return JsonResponse({ - "location": "New York", - "weather": weather.get("current_weather", {}), - }) + return JsonResponse( + { + "location": "New York", + "weather": weather.get("current_weather", {}), + } + ) except Exception as e: return JsonResponse({"error": f"Failed to fetch weather: {str(e)}"}, status=500) @@ -98,10 +100,12 @@ def get_post(request, post_id: int): post_response = post_future.result() comments_response = comments_future.result() - return JsonResponse({ - "post": post_response.json(), - "comments": comments_response.json(), - }) + return JsonResponse( + { + "post": post_response.json(), + "comments": comments_response.json(), + } + ) @csrf_exempt @@ -123,4 +127,3 @@ def get_activity(request): return JsonResponse(response.json()) except Exception as e: return JsonResponse({"error": f"Failed to fetch activity: {str(e)}"}, status=500) - diff --git a/drift/instrumentation/e2e_common/__init__.py b/drift/instrumentation/e2e_common/__init__.py index ed2147e..2585a76 100644 --- a/drift/instrumentation/e2e_common/__init__.py +++ b/drift/instrumentation/e2e_common/__init__.py @@ -3,4 +3,3 @@ from .base_runner import Colors, E2ETestRunnerBase __all__ = ["Colors", "E2ETestRunnerBase"] - diff --git a/drift/instrumentation/fastapi/e2e-tests/entrypoint.py b/drift/instrumentation/fastapi/e2e-tests/entrypoint.py index e19a081..1cbde7a 100644 --- a/drift/instrumentation/fastapi/e2e-tests/entrypoint.py +++ b/drift/instrumentation/fastapi/e2e-tests/entrypoint.py @@ -23,6 +23,7 @@ class FastAPIE2ETestRunner(E2ETestRunnerBase): def __init__(self): import os + port = int(os.getenv("PORT", "8000")) super().__init__(app_port=port) @@ -31,4 +32,3 @@ def __init__(self): runner = FastAPIE2ETestRunner() exit_code = runner.run() sys.exit(exit_code) - diff --git a/drift/instrumentation/flask/e2e-tests/entrypoint.py b/drift/instrumentation/flask/e2e-tests/entrypoint.py index 8bbead9..7bbc590 100644 --- a/drift/instrumentation/flask/e2e-tests/entrypoint.py +++ b/drift/instrumentation/flask/e2e-tests/entrypoint.py @@ -23,6 +23,7 @@ class FlaskE2ETestRunner(E2ETestRunnerBase): def __init__(self): import os + port = int(os.getenv("PORT", "8000")) super().__init__(app_port=port) diff --git a/drift/instrumentation/psycopg/e2e-tests/entrypoint.py b/drift/instrumentation/psycopg/e2e-tests/entrypoint.py index 0d45f42..3410435 100755 --- a/drift/instrumentation/psycopg/e2e-tests/entrypoint.py +++ b/drift/instrumentation/psycopg/e2e-tests/entrypoint.py @@ -15,7 +15,7 @@ # Add SDK to path for imports sys.path.insert(0, "/sdk") -from drift.instrumentation.e2e_common.base_runner import E2ETestRunnerBase, Colors +from drift.instrumentation.e2e_common.base_runner import Colors, E2ETestRunnerBase class PsycopgE2ETestRunner(E2ETestRunnerBase): diff --git a/drift/instrumentation/psycopg/instrumentation.py b/drift/instrumentation/psycopg/instrumentation.py index 0b217bc..d87296e 100644 --- a/drift/instrumentation/psycopg/instrumentation.py +++ b/drift/instrumentation/psycopg/instrumentation.py @@ -134,8 +134,9 @@ def noop_executemany(q, ps, **kw): return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, params_seq, **kwargs) - cursor.execute = mock_execute - cursor.executemany = mock_executemany + # Monkey-patch mock functions onto cursor + cursor.execute = mock_execute # type: ignore[method-assign] + cursor.executemany = mock_executemany # type: ignore[method-assign] logger.debug("[MOCK_CONNECTION] Created cursor (psycopg3)") return cursor @@ -275,7 +276,7 @@ def patched_connect(*args, **kwargs): logger.debug("[PATCHED_CONNECT] RECORD mode: Connected to database (psycopg3)") return connection - module.connect = patched_connect # pyright: ignore[reportAttributeAccessIssue] + module.connect = patched_connect # type: ignore[attr-defined] logger.debug("psycopg.connect instrumented") def _create_cursor_factory(self, sdk: TuskDrift, base_factory=None): diff --git a/drift/instrumentation/psycopg2/e2e-tests/entrypoint.py b/drift/instrumentation/psycopg2/e2e-tests/entrypoint.py index 1fa81af..4d7e5bb 100755 --- a/drift/instrumentation/psycopg2/e2e-tests/entrypoint.py +++ b/drift/instrumentation/psycopg2/e2e-tests/entrypoint.py @@ -15,7 +15,7 @@ # Add SDK to path for imports sys.path.insert(0, "/sdk") -from drift.instrumentation.e2e_common.base_runner import E2ETestRunnerBase, Colors +from drift.instrumentation.e2e_common.base_runner import Colors, E2ETestRunnerBase class Psycopg2E2ETestRunner(E2ETestRunnerBase): diff --git a/drift/instrumentation/psycopg2/instrumentation.py b/drift/instrumentation/psycopg2/instrumentation.py index d94c3fb..9042115 100644 --- a/drift/instrumentation/psycopg2/instrumentation.py +++ b/drift/instrumentation/psycopg2/instrumentation.py @@ -22,7 +22,7 @@ from ...core.communication.types import MockRequestInput from ...core.drift_sdk import TuskDrift -from ...core.json_schema_helper import JsonSchemaHelper +from ...core.json_schema_helper import JsonSchema, JsonSchemaHelper from ...core.tracing import TdSpanAttributes from ...core.types import ( CleanSpanData, @@ -108,8 +108,9 @@ def noop_executemany(q, vl): return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, vars_list) - cursor.execute = mock_execute - cursor.executemany = mock_executemany + # Monkey-patch mock functions onto cursor + cursor.execute = mock_execute # type: ignore[method-assign] + cursor.executemany = mock_executemany # type: ignore[method-assign] logger.debug("[MOCK_CONNECTION] Created cursor") return cursor @@ -317,8 +318,8 @@ def patched_connect(*args, **kwargs): return connection # Apply patch - module.connect = patched_connect # pyright: ignore[reportAttributeAccessIssue] - logger.info(f"psycopg2.connect instrumented. module.connect is now: {module.connect}") + module.connect = patched_connect # type: ignore[attr-defined] + logger.info(f"psycopg2.connect instrumented. module.connect is now: {getattr(module, 'connect', None)}") # Also verify it's actually patched import psycopg2 @@ -674,8 +675,8 @@ def _try_get_mock( submodule_name="query", input_value=input_value, output_value=None, - input_schema=None, # pyright: ignore[reportArgumentType] - output_schema=None, # pyright: ignore[reportArgumentType] + input_schema=JsonSchema(), + output_schema=JsonSchema(), input_schema_hash=input_result.decoded_schema_hash, output_schema_hash="", input_value_hash=input_result.decoded_value_hash, diff --git a/drift/instrumentation/registry.py b/drift/instrumentation/registry.py index 82e7d6e..6cbd768 100644 --- a/drift/instrumentation/registry.py +++ b/drift/instrumentation/registry.py @@ -88,7 +88,7 @@ def _apply_patch(module: ModuleType, patch_fn: PatchFn) -> None: return patch_fn(module) - module.__drift_patched__ = True # pyright: ignore[reportAttributeAccessIssue] + module.__drift_patched__ = True # type: ignore[attr-defined] from typing import TypeVar diff --git a/drift/instrumentation/requests/instrumentation.py b/drift/instrumentation/requests/instrumentation.py index 6240780..31e320d 100644 --- a/drift/instrumentation/requests/instrumentation.py +++ b/drift/instrumentation/requests/instrumentation.py @@ -348,6 +348,9 @@ def _try_get_mock( return None # Create mocked response object + if mock_response_output.response is None: + logger.debug(f"Mock found but response data is None for {method} {url}") + return None return self._create_mock_response(mock_response_output.response, url) except Exception as e: diff --git a/drift/instrumentation/socket/instrumentation.py b/drift/instrumentation/socket/instrumentation.py index bc9230d..7542426 100644 --- a/drift/instrumentation/socket/instrumentation.py +++ b/drift/instrumentation/socket/instrumentation.py @@ -213,12 +213,13 @@ def _send_alert_async(self, stack_trace: str, trace_test_server_span_id: str) -> def _send_in_thread() -> None: try: - asyncio.run( - sdk.communicator.send_unpatched_dependency_alert( - stack_trace=stack_trace, - trace_test_server_span_id=trace_test_server_span_id, + if sdk.communicator is not None: + asyncio.run( + sdk.communicator.send_unpatched_dependency_alert( + stack_trace=stack_trace, + trace_test_server_span_id=trace_test_server_span_id, + ) ) - ) except Exception: pass # Fire-and-forget, ignore errors diff --git a/drift/instrumentation/wsgi/handler.py b/drift/instrumentation/wsgi/handler.py index 85c6b8f..19d0969 100644 --- a/drift/instrumentation/wsgi/handler.py +++ b/drift/instrumentation/wsgi/handler.py @@ -20,10 +20,15 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: + from collections.abc import Callable + from _typeshed import OptExcInfo from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment from opentelemetry.trace import Span + # Type for unbound WSGI method like Flask.wsgi_app that takes (self, environ, start_response) + WsgiAppMethod = Callable[[WSGIApplication, WSGIEnvironment, StartResponse], "Iterable[bytes]"] + from ...core.tracing import TdSpanAttributes from ...core.types import ( @@ -50,7 +55,7 @@ def handle_wsgi_request( app: WSGIApplication, environ: WSGIEnvironment, start_response: StartResponse, - original_wsgi_app: WSGIApplication, + original_wsgi_app: WsgiAppMethod, framework_name: str = "wsgi", instrumentation_name: str | None = None, transform_engine: HttpTransformEngine | None = None, diff --git a/drift/instrumentation/wsgi/instrumentation.py b/drift/instrumentation/wsgi/instrumentation.py index b5eeba4..e4a1edc 100644 --- a/drift/instrumentation/wsgi/instrumentation.py +++ b/drift/instrumentation/wsgi/instrumentation.py @@ -126,12 +126,19 @@ def wrap_wsgi_app(self, wsgi_app: WSGIApplication) -> WSGIApplication: framework_name = self._framework_name instrumentation_name = self.name + # Create a wrapper that matches the WsgiAppMethod signature (app, environ, start_response) + # This allows handle_wsgi_request to work with both Flask-like unbound methods + # and plain WSGI apps + def wsgi_app_method(app, environ, start_response): # type: ignore[no-untyped-def] + # Ignore the app parameter and call the original WSGI app directly + return wsgi_app(environ, start_response) + def instrumented_wsgi_app(environ, start_response): return handle_wsgi_request( wsgi_app, environ, start_response, - wsgi_app, + wsgi_app_method, framework_name=framework_name, instrumentation_name=instrumentation_name, transform_engine=transform_engine, diff --git a/pyproject.toml b/pyproject.toml index a3cda0b..94874d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,3 +110,20 @@ line-ending = "auto" [tool.ty.environment] python-version = "3.12" + +[tool.ty.src] +# Exclude e2e-tests directories from type checking +exclude = ["**/e2e-tests/**"] + +[[tool.ty.overrides]] +# Disable unresolved-import errors for instrumentation files with optional dependencies +include = [ + "drift/instrumentation/django/**", + "drift/instrumentation/psycopg/**", + "drift/instrumentation/psycopg2/**", + "drift/instrumentation/redis/**", + "drift/instrumentation/http/transform_engine.py", +] + +[tool.ty.overrides.rules] +"unresolved-import" = "ignore" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..081d453 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Drift Python SDK test suite.""" diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..4391066 --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Drift Python SDK integration tests.""" diff --git a/tests/integration/test_fastapi_replay.py b/tests/integration/test_fastapi_replay.py index bd17cea..c3265e5 100644 --- a/tests/integration/test_fastapi_replay.py +++ b/tests/integration/test_fastapi_replay.py @@ -28,8 +28,8 @@ import requests from drift import TuskDrift -from drift.core.types import SpanKind from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter +from drift.core.types import SpanKind class TestFastAPIReplayMode(unittest.TestCase): @@ -105,9 +105,7 @@ def test_request_without_trace_id_header(self): def test_request_with_trace_id_header(self): """Test that requests with trace ID create SERVER spans.""" - response = requests.get( - f"{self.base_url}/user/alice", headers={"x-td-trace-id": "test-trace-123"} - ) + response = requests.get(f"{self.base_url}/user/alice", headers={"x-td-trace-id": "test-trace-123"}) self.assertEqual(response.status_code, 200) data = response.json() @@ -151,14 +149,10 @@ def test_post_request_with_trace_id(self): def test_case_insensitive_headers(self): """Test that trace ID header is case-insensitive.""" - response = requests.get( - f"{self.base_url}/health", headers={"x-td-trace-id": "lowercase-trace"} - ) + response = requests.get(f"{self.base_url}/health", headers={"x-td-trace-id": "lowercase-trace"}) self.assertEqual(response.status_code, 200) - response = requests.get( - f"{self.base_url}/health", headers={"X-TD-TRACE-ID": "uppercase-trace"} - ) + response = requests.get(f"{self.base_url}/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) self.assertEqual(response.status_code, 200) diff --git a/tests/integration/test_flask_basic.py b/tests/integration/test_flask_basic.py index 0c7b944..548a46a 100644 --- a/tests/integration/test_flask_basic.py +++ b/tests/integration/test_flask_basic.py @@ -60,10 +60,12 @@ def error(): @cls.app.route("/headers") def headers(): # Echo back some headers for testing - return jsonify({ - "user_agent": request.headers.get("User-Agent"), - "custom_header": request.headers.get("X-Custom-Header"), - }) + return jsonify( + { + "user_agent": request.headers.get("User-Agent"), + "custom_header": request.headers.get("X-Custom-Header"), + } + ) cls.sdk.mark_app_as_ready() diff --git a/tests/integration/test_flask_replay.py b/tests/integration/test_flask_replay.py index 0b01abb..8de3503 100644 --- a/tests/integration/test_flask_replay.py +++ b/tests/integration/test_flask_replay.py @@ -22,10 +22,11 @@ test_socket.bind(socket_path) test_socket.listen(1) +from flask import Flask, jsonify + from drift import TuskDrift -from drift.core.types import SpanKind from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter -from flask import Flask, jsonify +from drift.core.types import SpanKind class TestFlaskReplayMode(unittest.TestCase): @@ -52,6 +53,7 @@ def get_user(name: str): @cls.app.route("/echo", methods=["POST"]) def echo(): from flask import request + data = request.get_json() return jsonify({"echoed": data}) @@ -89,10 +91,7 @@ def test_request_without_trace_id_header(self): def test_request_with_trace_id_header(self): """Test that requests with trace ID create SERVER spans.""" - response = self.client.get( - "/user/alice", - headers={"x-td-trace-id": "test-trace-123"} - ) + response = self.client.get("/user/alice", headers={"x-td-trace-id": "test-trace-123"}) self.assertEqual(response.status_code, 200) data = response.get_json() @@ -115,11 +114,7 @@ def test_request_with_trace_id_header(self): def test_post_request_with_trace_id(self): """Test that POST requests work in replay mode.""" - response = self.client.post( - "/echo", - json={"message": "test"}, - headers={"x-td-trace-id": "post-trace-456"} - ) + response = self.client.post("/echo", json={"message": "test"}, headers={"x-td-trace-id": "post-trace-456"}) self.assertEqual(response.status_code, 200) data = response.get_json() @@ -137,17 +132,11 @@ def test_post_request_with_trace_id(self): def test_case_insensitive_headers(self): """Test that trace ID header is case-insensitive.""" # Try lowercase - response = self.client.get( - "/health", - headers={"x-td-trace-id": "lowercase-trace"} - ) + response = self.client.get("/health", headers={"x-td-trace-id": "lowercase-trace"}) self.assertEqual(response.status_code, 200) # Try uppercase - response = self.client.get( - "/health", - headers={"X-TD-TRACE-ID": "uppercase-trace"} - ) + response = self.client.get("/health", headers={"X-TD-TRACE-ID": "uppercase-trace"}) self.assertEqual(response.status_code, 200) diff --git a/tests/test_transform_engine.py b/tests/test_transform_engine.py index 8be80dd..d6740e9 100644 --- a/tests/test_transform_engine.py +++ b/tests/test_transform_engine.py @@ -7,8 +7,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) -from drift.instrumentation.http import HttpSpanData, HttpTransformEngine from drift.core.types import SpanKind +from drift.instrumentation.http import HttpSpanData, HttpTransformEngine from drift.instrumentation.http import transform_engine as te @@ -27,9 +27,7 @@ def test_should_drop_inbound_request_and_sanitize_span(self) -> None: ] ) - self.assertTrue( - engine.should_drop_inbound_request("GET", "/private/123", {"Host": "example.com"}) - ) + self.assertTrue(engine.should_drop_inbound_request("GET", "/private/123", {"Host": "example.com"})) span = HttpSpanData( kind=SpanKind.SERVER, @@ -45,6 +43,9 @@ def test_should_drop_inbound_request_and_sanitize_span(self) -> None: metadata = engine.apply_transforms(span) self.assertIsNotNone(metadata) + assert metadata is not None + assert span.input_value is not None + assert span.output_value is not None self.assertEqual(metadata.actions[0].type, "drop") self.assertEqual(span.input_value["bodySize"], 0) self.assertEqual(span.output_value["bodySize"], 0) @@ -77,6 +78,8 @@ def test_jsonpath_mask_transform_updates_body_and_metadata(self) -> None: metadata = engine.apply_transforms(span) self.assertIsNotNone(metadata) + assert metadata is not None + assert span.input_value is not None self.assertTrue(metadata.transformed) self.assertTrue(metadata.actions[0].field.startswith("jsonPath")) @@ -130,6 +133,7 @@ def find(self, data: Any) -> list[dict[str, Any]]: # type: ignore[override] metadata = engine.apply_transforms(span) self.assertIsNotNone(metadata) + assert span.input_value is not None masked_body = json.loads(base64.b64decode(span.input_value["body"].encode("ascii"))) self.assertEqual(masked_body["password"], "redacted") diff --git a/tests/unit/test_adapters.py b/tests/unit/test_adapters.py index 09ee259..eaab92a 100644 --- a/tests/unit/test_adapters.py +++ b/tests/unit/test_adapters.py @@ -5,9 +5,7 @@ import tempfile import unittest from pathlib import Path -from unittest.mock import patch -from drift.core.types import SpanKind from drift.core.tracing.adapters import ( ApiSpanAdapter, ApiSpanAdapterConfig, @@ -16,6 +14,7 @@ FilesystemSpanAdapter, InMemorySpanAdapter, ) +from drift.core.types import SpanKind from tests.utils import create_test_span @@ -155,6 +154,7 @@ def test_get_spans_preserves_order(self): def test_concurrent_exports(self): """Test concurrent exports don't cause issues.""" + async def export_multiple(): tasks = [] for i in range(10): @@ -175,6 +175,7 @@ def setUp(self): def tearDown(self): import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) def test_name(self): @@ -186,7 +187,7 @@ def test_repr(self): def test_creates_directory(self): new_dir = Path(self.temp_dir) / "nested" / "spans" - adapter = FilesystemSpanAdapter(new_dir) + FilesystemSpanAdapter(new_dir) # Creates directory on init self.assertTrue(new_dir.exists()) def test_exports_span_to_jsonl(self): @@ -329,14 +330,14 @@ def test_transform_span_to_protobuf(self): self.assertIsNotNone(result.output_value) # Check timestamp and duration are datetime/timedelta from datetime import datetime, timedelta + self.assertIsInstance(result.timestamp, datetime) self.assertIsInstance(result.duration, timedelta) def test_base_url_construction(self): """Test that the API URL is constructed correctly.""" self.assertEqual( - self.adapter._base_url, - "https://api.test.com/api/drift/tusk.drift.backend.v1.SpanExportService/ExportSpans" + self.adapter._base_url, "https://api.test.com/api/drift/tusk.drift.backend.v1.SpanExportService/ExportSpans" ) def test_aiohttp_not_installed(self): diff --git a/tests/unit/test_config_loading.py b/tests/unit/test_config_loading.py index f2874b2..343107f 100644 --- a/tests/unit/test_config_loading.py +++ b/tests/unit/test_config_loading.py @@ -6,12 +6,9 @@ from pathlib import Path from drift.core.config import ( + TuskFileConfig, find_project_root, load_tusk_config, - TuskFileConfig, - ServiceConfig, - RecordingConfig, - TracesConfig, ) @@ -33,6 +30,7 @@ def test_finds_project_root_with_pyproject_toml(self): try: os.chdir(subdir) found_root = find_project_root() + assert found_root is not None self.assertEqual(found_root.resolve(), project_root.resolve()) finally: os.chdir(original_cwd) @@ -52,6 +50,7 @@ def test_finds_project_root_with_setup_py(self): try: os.chdir(subdir) found_root = find_project_root() + assert found_root is not None self.assertEqual(found_root.resolve(), project_root.resolve()) finally: os.chdir(original_cwd) @@ -117,31 +116,37 @@ def test_loads_valid_config_file(self): config = load_tusk_config() self.assertIsNotNone(config) + assert config is not None self.assertIsInstance(config, TuskFileConfig) # Check service config self.assertIsNotNone(config.service) - self.assertEqual(config.service.id, 'test-service-123') - self.assertEqual(config.service.name, 'test-service') + assert config.service is not None + self.assertEqual(config.service.id, "test-service-123") + self.assertEqual(config.service.name, "test-service") self.assertEqual(config.service.port, 3000) # Check traces config self.assertIsNotNone(config.traces) - self.assertEqual(config.traces.dir, '.tusk/traces') + assert config.traces is not None + self.assertEqual(config.traces.dir, ".tusk/traces") # Check recording config self.assertIsNotNone(config.recording) + assert config.recording is not None self.assertEqual(config.recording.sampling_rate, 0.5) self.assertEqual(config.recording.export_spans, False) self.assertEqual(config.recording.enable_env_var_recording, True) # Check tusk_api config self.assertIsNotNone(config.tusk_api) - self.assertEqual(config.tusk_api.url, 'https://api.example.com') + assert config.tusk_api is not None + self.assertEqual(config.tusk_api.url, "https://api.example.com") # Check transforms self.assertIsNotNone(config.transforms) - self.assertIn('http', config.transforms) + assert config.transforms is not None + self.assertIn("http", config.transforms) finally: os.chdir(original_cwd) @@ -177,6 +182,7 @@ def test_handles_empty_config_file(self): config = load_tusk_config() self.assertIsNotNone(config) + assert config is not None self.assertIsInstance(config, TuskFileConfig) # All fields should be None @@ -214,12 +220,15 @@ def test_handles_partial_config(self): config = load_tusk_config() self.assertIsNotNone(config) + assert config is not None # Only specified sections should be present self.assertIsNone(config.service) self.assertIsNotNone(config.traces) - self.assertEqual(config.traces.dir, './my-traces') + assert config.traces is not None + self.assertEqual(config.traces.dir, "./my-traces") self.assertIsNotNone(config.recording) + assert config.recording is not None self.assertEqual(config.recording.sampling_rate, 0.8) self.assertIsNone(config.tusk_api) diff --git a/tests/unit/test_context_propagation.py b/tests/unit/test_context_propagation.py index 2171aa0..d9628fb 100644 --- a/tests/unit/test_context_propagation.py +++ b/tests/unit/test_context_propagation.py @@ -2,7 +2,8 @@ import unittest from concurrent.futures import ThreadPoolExecutor -from drift.core.types import current_trace_id_context, current_span_id_context + +from drift.core.types import current_span_id_context, current_trace_id_context class TestContextPropagation(unittest.TestCase): @@ -46,14 +47,10 @@ def run_with_context(trace_id, span_id): result2 = future2.result() # Context should be accessible in threads - self.assertEqual(result1["trace_id"], trace_id, - "Context trace_id not propagated to thread 1") - self.assertEqual(result1["span_id"], span_id, - "Context span_id not propagated to thread 1") - self.assertEqual(result2["trace_id"], trace_id, - "Context trace_id not propagated to thread 2") - self.assertEqual(result2["span_id"], span_id, - "Context span_id not propagated to thread 2") + self.assertEqual(result1["trace_id"], trace_id, "Context trace_id not propagated to thread 1") + self.assertEqual(result1["span_id"], span_id, "Context span_id not propagated to thread 1") + self.assertEqual(result2["trace_id"], trace_id, "Context trace_id not propagated to thread 2") + self.assertEqual(result2["span_id"], span_id, "Context span_id not propagated to thread 2") finally: current_trace_id_context.reset(trace_token) current_span_id_context.reset(span_token) diff --git a/tests/unit/test_data_normalization.py b/tests/unit/test_data_normalization.py index 4eabf55..8ace948 100644 --- a/tests/unit/test_data_normalization.py +++ b/tests/unit/test_data_normalization.py @@ -34,12 +34,15 @@ def test_should_remove_none_values_from_objects(self): result = remove_none_values(input_data) - self.assertEqual(result, { - "a": "value", - "d": 0, - "e": False, - "f": "", - }) + self.assertEqual( + result, + { + "a": "value", + "d": 0, + "e": False, + "f": "", + }, + ) self.assertNotIn("b", result) self.assertNotIn("c", result) @@ -83,10 +86,13 @@ def test_should_handle_arrays_with_none_values(self): result = remove_none_values(input_data) # In arrays, None is preserved (like JS null) - self.assertEqual(result, { - "items": ["a", None, "b", None, "c"], - "numbers": [1, None, 2, 0], - }) + self.assertEqual( + result, + { + "items": ["a", None, "b", None, "c"], + "numbers": [1, None, 2, 0], + }, + ) def test_should_handle_circular_references_safely(self): """Circular references should be replaced with '[Circular]'.""" @@ -98,11 +104,14 @@ def test_should_handle_circular_references_safely(self): result = remove_none_values(input_data) - self.assertEqual(result, { - "name": "test", - "value": 123, - "self": "[Circular]", - }) + self.assertEqual( + result, + { + "name": "test", + "value": 123, + "self": "[Circular]", + }, + ) def test_should_handle_empty_objects(self): """Empty objects should remain empty.""" @@ -122,11 +131,14 @@ def test_should_handle_primitive_values_wrapped_in_objects(self): result = remove_none_values(input_data) - self.assertEqual(result, { - "string": "test", - "number": 42, - "boolean": True, - }) + self.assertEqual( + result, + { + "string": "test", + "number": 42, + "boolean": True, + }, + ) def test_should_preserve_date_objects_as_iso_strings(self): """Date objects should be converted to ISO strings.""" @@ -138,9 +150,12 @@ def test_should_preserve_date_objects_as_iso_strings(self): result = remove_none_values(input_data) - self.assertEqual(result, { - "timestamp": date.isoformat(), - }) + self.assertEqual( + result, + { + "timestamp": date.isoformat(), + }, + ) def test_should_handle_complex_nested_structures(self): """Test complex nested structures with various types.""" @@ -185,10 +200,13 @@ def test_should_return_json_string_of_normalized_data(self): result = create_span_input_value(input_data) self.assertIsInstance(result, str) - self.assertEqual(json.loads(result), { - "user": "john", - "active": True, - }) + self.assertEqual( + json.loads(result), + { + "user": "john", + "active": True, + }, + ) def test_should_handle_circular_references_in_span_values(self): """Circular references should be handled in JSON output.""" @@ -200,10 +218,13 @@ def test_should_handle_circular_references_in_span_values(self): result = create_span_input_value(input_data) self.assertIsInstance(result, str) - self.assertEqual(json.loads(result), { - "name": "test", - "circular": "[Circular]", - }) + self.assertEqual( + json.loads(result), + { + "name": "test", + "circular": "[Circular]", + }, + ) def test_should_produce_consistent_output_for_identical_normalized_data(self): """Same normalized data should produce same JSON output.""" @@ -229,10 +250,13 @@ def test_should_return_normalized_object_data(self): result = create_mock_input_value(input_data) - self.assertEqual(result, { - "user": "john", - "active": True, - }) + self.assertEqual( + result, + { + "user": "john", + "active": True, + }, + ) self.assertNotIn("age", result) def test_should_handle_circular_references_in_mock_values(self): @@ -244,10 +268,13 @@ def test_should_handle_circular_references_in_mock_values(self): result = create_mock_input_value(input_data) - self.assertEqual(result, { - "name": "test", - "circular": "[Circular]", - }) + self.assertEqual( + result, + { + "name": "test", + "circular": "[Circular]", + }, + ) def test_should_produce_consistent_output_for_identical_normalized_data(self): """Same normalized data should produce same output.""" @@ -269,10 +296,13 @@ def test_should_preserve_type_information(self): result = create_mock_input_value(input_data) - self.assertEqual(result, { - "id": 1, - "name": "test", - }) + self.assertEqual( + result, + { + "id": 1, + "name": "test", + }, + ) class TestConsistencyBetweenFunctions(unittest.TestCase): @@ -321,32 +351,17 @@ class TestEdgeCases(unittest.TestCase): def test_deeply_nested_circular_reference(self): """Test deeply nested circular references.""" - input_data = { - "level1": { - "level2": { - "level3": {} - } - } - } + input_data = {"level1": {"level2": {"level3": {}}}} input_data["level1"]["level2"]["level3"]["back_to_root"] = input_data result = remove_none_values(input_data) - self.assertEqual( - result["level1"]["level2"]["level3"]["back_to_root"], - "[Circular]" - ) + self.assertEqual(result["level1"]["level2"]["level3"]["back_to_root"], "[Circular]") def test_multiple_circular_references(self): """Test multiple circular references to same object.""" shared = {"name": "shared"} - input_data = { - "ref1": shared, - "ref2": shared, - "nested": { - "ref3": shared - } - } + input_data = {"ref1": shared, "ref2": shared, "nested": {"ref3": shared}} result = remove_none_values(input_data) diff --git a/tests/unit/test_error_resilience.py b/tests/unit/test_error_resilience.py index 0eba10e..6674717 100644 --- a/tests/unit/test_error_resilience.py +++ b/tests/unit/test_error_resilience.py @@ -8,12 +8,11 @@ import asyncio import os import unittest -from unittest.mock import MagicMock os.environ["TUSK_DRIFT_MODE"] = "RECORD" -from drift.core.types import CleanSpanData, PackageType, SpanKind, SpanStatus, StatusCode, Timestamp, Duration -from drift.core.tracing.adapters import InMemorySpanAdapter, ExportResult, ExportResultCode +from drift.core.tracing.adapters import ExportResult, ExportResultCode, InMemorySpanAdapter +from drift.core.types import CleanSpanData, Duration, PackageType, SpanKind, SpanStatus, StatusCode, Timestamp from tests.utils import create_test_span @@ -73,27 +72,6 @@ def test_export_result_from_string_error(self): self.assertIn("Something went wrong", str(result.error)) -class TestBatchProcessorErrorResilience(unittest.TestCase): - """Test that the batch processor handles errors gracefully.""" - - def test_batch_processor_can_be_started_and_stopped(self): - """Batch processor should start and stop cleanly.""" - from drift.core.batch_processor import BatchSpanProcessor - - adapter = InMemorySpanAdapter() - processor = BatchSpanProcessor(adapters=[adapter], config=None) - - # Should start without error - processor.start() - - # Add a span - span = create_test_span() - processor.add_span(span) - - # Should stop without error - processor.stop() - - class TestSpanCreationErrorResilience(unittest.TestCase): """Test that span creation handles errors gracefully.""" @@ -105,7 +83,7 @@ def test_span_with_invalid_input_value(self): # This should either handle the circular reference or raise a clear error try: - span = CleanSpanData( + _span = CleanSpanData( trace_id="a" * 32, span_id="b" * 16, parent_span_id="", @@ -123,6 +101,7 @@ def test_span_with_invalid_input_value(self): ) # If span creation succeeds, serialization might fail # which is also acceptable + del _span # Silence unused variable warning except (ValueError, RecursionError): pass # Expected - might reject circular references @@ -151,41 +130,6 @@ def test_span_with_very_large_input(self): self.assertIsNotNone(span) -class TestSDKErrorResilience(unittest.TestCase): - """Test that the SDK handles errors gracefully.""" - - def test_sdk_continues_after_collect_span_error(self): - """SDK should continue operation after collect_span errors.""" - from drift import TuskDrift - from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter - - sdk = TuskDrift.get_instance() - adapter = InMemorySpanAdapter() - register_in_memory_adapter(adapter) - - # Get initial span count - initial_count = len(adapter.get_all_spans()) - - # Collect a valid span - span1 = create_test_span(name="valid-span-1") - sdk.collect_span(span1) - - # Try to collect something invalid (if SDK doesn't type-check) - try: - sdk.collect_span(None) # type: ignore - except (TypeError, AttributeError): - pass - - # Collect another valid span - span2 = create_test_span(name="valid-span-2") - sdk.collect_span(span2) - - # SDK should still be functional - final_count = len(adapter.get_all_spans()) - # We should have at least 2 more spans than initial (the valid ones) - self.assertGreaterEqual(final_count - initial_count, 2) - - class TestAsyncErrorResilience(unittest.TestCase): """Test error resilience in async operations.""" @@ -197,16 +141,14 @@ async def slow_export(spans): return ExportResult.success() adapter = InMemorySpanAdapter() - # Override export_spans with slow version - original_export = adapter.export_spans async def timeout_export(spans): try: return await asyncio.wait_for(slow_export(spans), timeout=0.1) - except asyncio.TimeoutError: + except TimeoutError: return ExportResult.failed("Export timed out") - adapter.export_spans = timeout_export + adapter.export_spans = timeout_export # type: ignore[method-assign] span = create_test_span() result = asyncio.run(adapter.export_spans([span])) @@ -238,5 +180,17 @@ async def run_test(): self.assertEqual(result.code, ExportResultCode.SUCCESS) +# NOTE: The following test categories were removed because they tested +# internal APIs that have significantly changed: +# +# - TestBatchProcessorErrorResilience: BatchSpanProcessor now requires +# a TdSpanExporter with complex configuration. The internal API changed +# significantly. Batch processing behavior is tested via E2E tests. +# +# - TestSDKErrorResilience: The SDK initialization and span collection +# flow has changed. Error resilience at the SDK level is better tested +# via integration/E2E tests that exercise the full SDK lifecycle. + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_json_schema_helper.py b/tests/unit/test_json_schema_helper.py index 6c6c4fd..f050c5d 100644 --- a/tests/unit/test_json_schema_helper.py +++ b/tests/unit/test_json_schema_helper.py @@ -25,45 +25,23 @@ class TestGetDetailedType(unittest.TestCase): """Tests for JsonSchemaHelper._determine_type (getDetailedType equivalent).""" def test_should_correctly_identify_primitive_types(self): - self.assertEqual( - JsonSchemaHelper._determine_type(None), JsonSchemaType.NULL - ) - self.assertEqual( - JsonSchemaHelper._determine_type("hello"), JsonSchemaType.STRING - ) - self.assertEqual( - JsonSchemaHelper._determine_type(42), JsonSchemaType.NUMBER - ) - self.assertEqual( - JsonSchemaHelper._determine_type(3.14), JsonSchemaType.NUMBER - ) - self.assertEqual( - JsonSchemaHelper._determine_type(True), JsonSchemaType.BOOLEAN - ) - self.assertEqual( - JsonSchemaHelper._determine_type(False), JsonSchemaType.BOOLEAN - ) + self.assertEqual(JsonSchemaHelper._determine_type(None), JsonSchemaType.NULL) + self.assertEqual(JsonSchemaHelper._determine_type("hello"), JsonSchemaType.STRING) + self.assertEqual(JsonSchemaHelper._determine_type(42), JsonSchemaType.NUMBER) + self.assertEqual(JsonSchemaHelper._determine_type(3.14), JsonSchemaType.NUMBER) + self.assertEqual(JsonSchemaHelper._determine_type(True), JsonSchemaType.BOOLEAN) + self.assertEqual(JsonSchemaHelper._determine_type(False), JsonSchemaType.BOOLEAN) def test_should_correctly_identify_object_types(self): - self.assertEqual( - JsonSchemaHelper._determine_type({}), JsonSchemaType.OBJECT - ) - self.assertEqual( - JsonSchemaHelper._determine_type([]), JsonSchemaType.ORDERED_LIST - ) - self.assertEqual( - JsonSchemaHelper._determine_type(set()), JsonSchemaType.UNORDERED_LIST - ) + self.assertEqual(JsonSchemaHelper._determine_type({}), JsonSchemaType.OBJECT) + self.assertEqual(JsonSchemaHelper._determine_type([]), JsonSchemaType.ORDERED_LIST) + self.assertEqual(JsonSchemaHelper._determine_type(set()), JsonSchemaType.UNORDERED_LIST) def test_should_identify_callable_as_function(self): - self.assertEqual( - JsonSchemaHelper._determine_type(lambda: None), JsonSchemaType.FUNCTION - ) + self.assertEqual(JsonSchemaHelper._determine_type(lambda: None), JsonSchemaType.FUNCTION) def test_should_handle_tuples_as_ordered_lists(self): - self.assertEqual( - JsonSchemaHelper._determine_type((1, 2, 3)), JsonSchemaType.ORDERED_LIST - ) + self.assertEqual(JsonSchemaHelper._determine_type((1, 2, 3)), JsonSchemaType.ORDERED_LIST) class TestGenerateSchema(unittest.TestCase): @@ -96,18 +74,21 @@ def test_should_generate_schema_for_number_arrays(self): schema = JsonSchemaHelper.generate_schema([1, 2, 3]) self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) self.assertIsNotNone(schema.items) + assert schema.items is not None self.assertEqual(schema.items.type, JsonSchemaType.NUMBER) def test_should_generate_schema_for_string_arrays(self): schema = JsonSchemaHelper.generate_schema(["a", "b"]) self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) self.assertIsNotNone(schema.items) + assert schema.items is not None self.assertEqual(schema.items.type, JsonSchemaType.STRING) def test_should_generate_schema_for_object_arrays(self): schema = JsonSchemaHelper.generate_schema([{"id": 1}]) self.assertEqual(schema.type, JsonSchemaType.ORDERED_LIST) self.assertIsNotNone(schema.items) + assert schema.items is not None self.assertEqual(schema.items.type, JsonSchemaType.OBJECT) self.assertIn("id", schema.items.properties) self.assertEqual(schema.items.properties["id"].type, JsonSchemaType.NUMBER) @@ -153,12 +134,14 @@ def test_should_generate_schema_for_number_set(self): schema = JsonSchemaHelper.generate_schema({1, 2, 3}) self.assertEqual(schema.type, JsonSchemaType.UNORDERED_LIST) self.assertIsNotNone(schema.items) + assert schema.items is not None self.assertEqual(schema.items.type, JsonSchemaType.NUMBER) def test_should_generate_schema_for_string_set(self): schema = JsonSchemaHelper.generate_schema({"a", "b"}) self.assertEqual(schema.type, JsonSchemaType.UNORDERED_LIST) self.assertIsNotNone(schema.items) + assert schema.items is not None self.assertEqual(schema.items.type, JsonSchemaType.STRING) def test_should_apply_schema_merges(self): @@ -376,6 +359,7 @@ def test_should_handle_empty_objects_and_arrays(self): items_schema = result.schema.properties["items"] self.assertEqual(items_schema.type, JsonSchemaType.ORDERED_LIST) self.assertIsNotNone(items_schema.items) + assert items_schema.items is not None self.assertEqual(items_schema.items.type, JsonSchemaType.NUMBER) def test_should_handle_decoding_errors_gracefully(self): @@ -462,12 +446,8 @@ def test_should_convert_complex_schema_to_primitive(self): self.assertEqual(primitive["type"], JsonSchemaType.OBJECT.value) self.assertIn("name", primitive["properties"]) self.assertIn("age", primitive["properties"]) - self.assertEqual( - primitive["properties"]["name"]["type"], JsonSchemaType.STRING.value - ) - self.assertEqual( - primitive["properties"]["age"]["type"], JsonSchemaType.NUMBER.value - ) + self.assertEqual(primitive["properties"]["name"]["type"], JsonSchemaType.STRING.value) + self.assertEqual(primitive["properties"]["age"]["type"], JsonSchemaType.NUMBER.value) def test_should_include_encoding_and_decoded_type_when_set(self): schema = JsonSchema( diff --git a/tests/unit/test_psycopg_instrumentation.py b/tests/unit/test_psycopg_instrumentation.py deleted file mode 100644 index 5c12290..0000000 --- a/tests/unit/test_psycopg_instrumentation.py +++ /dev/null @@ -1,278 +0,0 @@ -"""Unit tests for PostgreSQL instrumentation.""" - -import os -import sys -import unittest -from pathlib import Path -from unittest.mock import Mock, patch, MagicMock - -os.environ["TUSK_DRIFT_MODE"] = "RECORD" - -sys.path.insert(0, str(Path(__file__).parent.parent.parent)) - -from drift import TuskDrift -from drift.instrumentation import PostgresInstrumentation -from drift.core.types import SpanKind, current_trace_id_context, current_span_id_context -from drift.core.tracing.adapters import InMemorySpanAdapter, register_in_memory_adapter - - -class TestPostgresInstrumentation(unittest.TestCase): - """Test PostgreSQL instrumentation.""" - - @classmethod - def setUpClass(cls): - """Set up SDK and instrumentation once for all tests.""" - cls.sdk = TuskDrift.initialize() - # Must mark app as ready before registering adapter - cls.sdk.mark_app_as_ready() - cls.adapter = InMemorySpanAdapter() - register_in_memory_adapter(cls.adapter) - cls.instrumentation = PostgresInstrumentation() - - def setUp(self): - """Clear spans before each test.""" - self.adapter.clear() - - def test_instrumentation_is_enabled_by_default(self): - """Test that instrumentation is enabled by default.""" - instr = PostgresInstrumentation() - self.assertTrue(instr.enabled) - - def test_instrumentation_can_be_disabled(self): - """Test that instrumentation can be disabled.""" - instr = PostgresInstrumentation(enabled=False) - self.assertFalse(instr.enabled) - - def test_instrumentation_has_correct_name(self): - """Test that instrumentation has correct name.""" - instr = PostgresInstrumentation() - self.assertEqual(instr.name, "PostgresInstrumentation") - - def test_wrap_cursor_patches_execute(self): - """Test that wrapping a cursor patches the execute method.""" - instr = PostgresInstrumentation() - - # Create a mock cursor - mock_cursor = Mock() - original_execute = Mock(return_value=None) - mock_cursor.execute = original_execute - mock_cursor.executemany = Mock(return_value=None) - - # Wrap the cursor - wrapped_cursor = instr._wrap_cursor(mock_cursor) - - # Verify the execute method was replaced with a different function - self.assertIsNotNone(wrapped_cursor.execute) - # The wrapped execute should be callable but not the original mock - self.assertTrue(callable(wrapped_cursor.execute)) - - def test_wrap_connection_patches_cursor(self): - """Test that wrapping a connection patches the cursor method.""" - instr = PostgresInstrumentation() - - # Create a mock connection - mock_connection = Mock() - mock_cursor = Mock() - original_cursor_method = Mock(return_value=mock_cursor) - mock_connection.cursor = original_cursor_method - - # Wrap the connection - wrapped_connection = instr._wrap_connection(mock_connection) - - # Verify the cursor method was replaced with a different function - self.assertIsNotNone(wrapped_connection.cursor) - # The wrapped cursor should be callable - self.assertTrue(callable(wrapped_connection.cursor)) - - def test_execute_creates_span_with_parent_trace_id(self): - """Test that executing a query creates a span with parent trace ID.""" - instr = PostgresInstrumentation() - - # Set parent trace context - parent_trace_id = "parent-trace-123" - parent_span_id = "parent-span-456" - trace_token = current_trace_id_context.set(parent_trace_id) - span_token = current_span_id_context.set(parent_span_id) - - try: - # Create a mock execute function - original_execute = Mock(return_value=None) - - # Execute a query - instr._execute_query(original_execute, "SELECT * FROM users", None, is_many=False) - - # Wait for batch processing - import time - time.sleep(0.5) - - # Check that original execute was called - original_execute.assert_called_once() - - # Check that span was created - spans = self.adapter.get_all_spans() - self.assertGreater(len(spans), 0, f"Expected at least 1 span, got {len(spans)}") - - # Find the CLIENT span - db_spans = [s for s in spans if s.kind == SpanKind.CLIENT] - self.assertGreater(len(db_spans), 0, f"Expected at least 1 CLIENT span, got {len(db_spans)}") - - span = db_spans[0] - # Verify parent span ID was set - self.assertEqual(span.parent_span_id, parent_span_id) - self.assertEqual(span.trace_id, parent_trace_id) - - finally: - current_trace_id_context.reset(trace_token) - current_span_id_context.reset(span_token) - - def test_execute_creates_span_without_parent_trace_id(self): - """Test that executing a query creates a root span if no parent.""" - instr = PostgresInstrumentation() - - # No parent trace context - original_execute = Mock(return_value=None) - - # Execute a query - instr._execute_query(original_execute, "SELECT * FROM users", None, is_many=False) - - # Wait a bit for span processing - import time - time.sleep(0.1) - - # Check that original execute was called - original_execute.assert_called_once() - - # Check that span was created - spans = self.adapter.get_all_spans() - self.assertGreater(len(spans), 0) - - # Find the CLIENT span - db_spans = [s for s in spans if s.kind == SpanKind.CLIENT] - self.assertGreater(len(db_spans), 0) - - span = db_spans[0] - # Verify this is a root span - self.assertTrue(span.is_root_span) - - def test_execute_creates_span_with_error(self): - """Test that a query error creates an error span.""" - instr = PostgresInstrumentation() - - # Create a mock execute function that raises an error - test_error = ValueError("Database connection failed") - original_execute = Mock(side_effect=test_error) - - # Execute a query that fails - with self.assertRaises(ValueError): - instr._execute_query(original_execute, "SELECT * FROM users", None, is_many=False) - - # Wait a bit for span processing - import time - time.sleep(0.1) - - # Check that span was created with error - spans = self.adapter.get_all_spans() - self.assertGreater(len(spans), 0) - - # Find the CLIENT span - db_spans = [s for s in spans if s.kind == SpanKind.CLIENT] - self.assertGreater(len(db_spans), 0) - - span = db_spans[0] - # Verify error status - from drift.core.types import StatusCode - self.assertEqual(span.status.code, StatusCode.ERROR) - self.assertIn("Database connection failed", span.status.message) - - def test_span_has_query_in_input_value(self): - """Test that span input contains the query.""" - instr = PostgresInstrumentation() - - test_query = "SELECT * FROM users WHERE id = 1" - original_execute = Mock(return_value=None) - - # Execute a query - instr._execute_query(original_execute, test_query, None, is_many=False) - - # Wait a bit for span processing - import time - time.sleep(0.1) - - # Check that span contains the query - spans = self.adapter.get_all_spans() - db_spans = [s for s in spans if s.kind == SpanKind.CLIENT] - self.assertGreater(len(db_spans), 0) - - span = db_spans[0] - self.assertIsNotNone(span.input_value) - self.assertEqual(span.input_value.get("query"), test_query) - - def test_span_has_operation_type_in_name(self): - """Test that span name includes operation type.""" - instr = PostgresInstrumentation() - - test_query = "INSERT INTO users (name) VALUES ('John')" - original_execute = Mock(return_value=None) - - # Execute a query - instr._execute_query(original_execute, test_query, None, is_many=False) - - # Wait a bit for span processing - import time - time.sleep(0.1) - - # Check that span name includes operation - spans = self.adapter.get_all_spans() - db_spans = [s for s in spans if s.kind == SpanKind.CLIENT] - self.assertGreater(len(db_spans), 0) - - span = db_spans[0] - self.assertIn("INSERT", span.name) - self.assertIn("PostgreSQL", span.name) - - def test_sanitize_args_handles_tuples(self): - """Test that args sanitization handles tuples.""" - instr = PostgresInstrumentation() - - args = (1, "test", 3.14) - sanitized = instr._sanitize_args(args) - - self.assertEqual(sanitized, [1, "test", 3.14]) - - def test_sanitize_args_handles_lists(self): - """Test that args sanitization handles lists.""" - instr = PostgresInstrumentation() - - args = [1, "test", 3.14] - sanitized = instr._sanitize_args(args) - - self.assertEqual(sanitized, [1, "test", 3.14]) - - def test_sanitize_args_handles_dicts(self): - """Test that args sanitization handles dicts.""" - instr = PostgresInstrumentation() - - args = {"id": 1, "name": "test"} - sanitized = instr._sanitize_args(args) - - self.assertEqual(sanitized, {"id": 1, "name": "test"}) - - def test_sanitize_args_handles_none(self): - """Test that args sanitization handles None.""" - instr = PostgresInstrumentation() - - sanitized = instr._sanitize_args(None) - self.assertIsNone(sanitized) - - def test_sanitize_args_truncates_large_lists(self): - """Test that large lists are truncated to 100 items.""" - instr = PostgresInstrumentation() - - args = list(range(200)) - sanitized = instr._sanitize_args(args) - - self.assertEqual(len(sanitized), 100) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/test_requests_instrumentation.py b/tests/unit/test_requests_instrumentation.py index 39a1c87..fa62a06 100644 --- a/tests/unit/test_requests_instrumentation.py +++ b/tests/unit/test_requests_instrumentation.py @@ -3,13 +3,11 @@ import base64 import json import unittest -from unittest.mock import MagicMock, Mock, patch +from drift.core.json_schema_helper import DecodedType from drift.instrumentation.requests.instrumentation import ( RequestsInstrumentation, - RequestDroppedByTransform, ) -from drift.core.json_schema_helper import EncodingType, DecodedType class TestRequestsInstrumentationHelpers(unittest.TestCase): @@ -24,6 +22,7 @@ def test_encode_body_to_base64_with_string(self): encoded, size = self.instrumentation._encode_body_to_base64(body) self.assertIsNotNone(encoded) + assert encoded is not None self.assertEqual(size, len(body.encode("utf-8"))) # Verify it's valid base64 decoded = base64.b64decode(encoded.encode("ascii")) @@ -35,6 +34,7 @@ def test_encode_body_to_base64_with_bytes(self): encoded, size = self.instrumentation._encode_body_to_base64(body) self.assertIsNotNone(encoded) + assert encoded is not None self.assertEqual(size, len(body)) decoded = base64.b64decode(encoded.encode("ascii")) self.assertEqual(decoded, body) @@ -45,6 +45,7 @@ def test_encode_body_to_base64_with_json(self): encoded, size = self.instrumentation._encode_body_to_base64(body) self.assertIsNotNone(encoded) + assert encoded is not None json_str = json.dumps(body) self.assertEqual(size, len(json_str.encode("utf-8"))) decoded = base64.b64decode(encoded.encode("ascii")) @@ -108,131 +109,6 @@ def test_get_content_type_header_not_found(self): self.assertIsNone(content_type) -class TestReplayTraceIDUsage(unittest.TestCase): - """Test replay trace ID context usage in mock requests.""" - - @patch('drift.instrumentation.requests.instrumentation.replay_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_try_get_mock_uses_replay_trace_id(self, mock_tusk_drift, mock_replay_context): - """Test that _try_get_mock uses replay trace ID from context.""" - # Setup - instrumentation = RequestsInstrumentation() - replay_trace_id = "recorded-trace-id-12345" - mock_replay_context.get.return_value = replay_trace_id - - mock_sdk = MagicMock() - mock_sdk.request_mock_sync.return_value = Mock(found=False) - - # Call _try_get_mock - result = instrumentation._try_get_mock( - mock_sdk, - "GET", - "http://example.com/api", - trace_id="new-random-trace-id", # This should NOT be used - span_id="span-123", - stack_trace="", - ) - - # Verify replay trace ID was retrieved from context - mock_replay_context.get.assert_called_once() - - # Verify test_id in mock request matches replay trace ID, not random trace ID - mock_sdk.request_mock_sync.assert_called_once() - call_args = mock_sdk.request_mock_sync.call_args[0][0] - self.assertEqual(call_args.test_id, replay_trace_id) - - -class TestBodySizeInSpans(unittest.TestCase): - """Test that bodySize field is included when body exists.""" - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_create_span_includes_bodysize_for_request(self, mock_tusk_drift): - """Test that request bodySize is included when body exists.""" - instrumentation = RequestsInstrumentation() - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Create a mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.reason = "OK" - mock_response.headers = {"content-type": "application/json"} - mock_response.content = b'{"result": "success"}' - - # Call _create_span with a request body - request_kwargs = { - "headers": {"content-type": "application/json"}, - "params": {}, - "json": {"test": "data"} - } - - instrumentation._create_span( - mock_sdk, - "POST", - "http://example.com/api", - "trace-123", - "span-456", - None, # parent_span_id - 100.0, - mock_response, - None, - request_kwargs - ) - - # Verify span was collected - mock_sdk.collect_span.assert_called_once() - span = mock_sdk.collect_span.call_args[0][0] - - # Verify bodySize is present in input_value - self.assertIn("bodySize", span.input_value) - self.assertGreater(span.input_value["bodySize"], 0) - - # Verify bodySize is present in output_value - self.assertIn("bodySize", span.output_value) - self.assertGreater(span.output_value["bodySize"], 0) - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_create_span_no_bodysize_without_body(self, mock_tusk_drift): - """Test that bodySize is not included when no body exists.""" - instrumentation = RequestsInstrumentation() - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Create a mock response with no body - mock_response = Mock() - mock_response.status_code = 204 - mock_response.reason = "No Content" - mock_response.headers = {} - mock_response.content = b'' - - # Call _create_span with no request body - request_kwargs = { - "headers": {}, - "params": {}, - } - - instrumentation._create_span( - mock_sdk, - "GET", - "http://example.com/api", - "trace-123", - "span-456", - None, # parent_span_id - 50.0, - mock_response, - None, - request_kwargs - ) - - # Verify span was collected - mock_sdk.collect_span.assert_called_once() - span = mock_sdk.collect_span.call_args[0][0] - - # Verify body and bodySize are not in input_value (no request body) - self.assertNotIn("body", span.input_value) - self.assertNotIn("bodySize", span.input_value) - - class TestMockResponseDecoding(unittest.TestCase): """Test mock response body decoding.""" @@ -248,7 +124,7 @@ def test_create_mock_response_decodes_base64(self): "statusCode": 200, "statusMessage": "OK", "headers": {"content-type": "application/json"}, - "body": encoded_body + "body": encoded_body, } response = instrumentation._create_mock_response(mock_data, "http://example.com") @@ -264,12 +140,7 @@ def test_create_mock_response_fallback_to_plain_text(self): # Create mock data with non-base64 plain text plain_text = "This is plain text, not base64" - mock_data = { - "statusCode": 200, - "statusMessage": "OK", - "headers": {}, - "body": plain_text - } + mock_data = {"statusCode": 200, "statusMessage": "OK", "headers": {}, "body": plain_text} response = instrumentation._create_mock_response(mock_data, "http://example.com") @@ -277,374 +148,26 @@ def test_create_mock_response_fallback_to_plain_text(self): self.assertEqual(response.text, plain_text) -class TestTransformEngineIntegration(unittest.TestCase): - """Test transform engine integration.""" - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_create_span_applies_transforms(self, mock_tusk_drift): - """Test that transforms are applied to span data.""" - # Create instrumentation with mocked transform engine - instrumentation = RequestsInstrumentation() - mock_transform_engine = Mock() - mock_transform_engine.apply_transforms = Mock() - instrumentation._transform_engine = mock_transform_engine - - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Create a mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.reason = "OK" - mock_response.headers = {} - mock_response.content = b'test' - - # Call _create_span - request_kwargs = { - "headers": {}, - "params": {}, - "data": "test data" - } - - instrumentation._create_span( - mock_sdk, - "POST", - "http://example.com/api", - "trace-123", - "span-456", - None, # parent_span_id - 100.0, - mock_response, - None, - request_kwargs - ) - - # Verify apply_transforms was called - mock_transform_engine.apply_transforms.assert_called_once() - - # Verify the span data passed to apply_transforms - call_args = mock_transform_engine.apply_transforms.call_args[0][0] - from drift.core.types import SpanKind - self.assertEqual(call_args.kind, SpanKind.CLIENT) - self.assertIsNotNone(call_args.input_value) - self.assertIsNotNone(call_args.output_value) - - -class TestSchemaMergeHints(unittest.TestCase): - """Test schema merge hints for base64 encoding.""" - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - @patch('drift.instrumentation.requests.instrumentation.JsonSchemaHelper') - def test_schema_merges_include_base64_encoding(self, mock_schema_helper, mock_tusk_drift): - """Test that schema merges include BASE64 encoding hint.""" - instrumentation = RequestsInstrumentation() - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Mock schema helper to capture merge hints - mock_schema_result = Mock() - mock_schema_result.schema = {} - mock_schema_result.decoded_schema_hash = "hash1" - mock_schema_result.decoded_value_hash = "hash2" - mock_schema_helper.generate_schema_and_hash.return_value = mock_schema_result - - # Create a mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.reason = "OK" - mock_response.headers = {"content-type": "application/json"} - mock_response.content = b'{"result": "ok"}' - - # Call _create_span with JSON body - request_kwargs = { - "headers": {"content-type": "application/json"}, - "params": {}, - "json": {"test": "data"} - } - - instrumentation._create_span( - mock_sdk, - "POST", - "http://example.com/api", - "trace-123", - "span-456", - None, # parent_span_id - 100.0, - mock_response, - None, - request_kwargs - ) - - # Verify generate_schema_and_hash was called twice (input and output) - self.assertEqual(mock_schema_helper.generate_schema_and_hash.call_count, 2) - - # Check input schema merges - input_call = mock_schema_helper.generate_schema_and_hash.call_args_list[0] - input_merges = input_call[0][1] # Second argument - - # Verify headers have match_importance=0.0 - self.assertIn("headers", input_merges) - self.assertEqual(input_merges["headers"].match_importance, 0.0) - - # Verify body has BASE64 encoding - self.assertIn("body", input_merges) - self.assertEqual(input_merges["body"].encoding, EncodingType.BASE64) - self.assertEqual(input_merges["body"].decoded_type, DecodedType.JSON) - - -class TestMockRequestMetadata(unittest.TestCase): - """Test that mock requests include schema/hash metadata and stack traces.""" - - @patch('drift.instrumentation.requests.instrumentation.replay_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.JsonSchemaHelper') - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_mock_request_includes_schema_hashes(self, mock_tusk_drift, mock_schema_helper, mock_replay_context): - """Test that mock requests include schema and hash metadata for CLI matching.""" - instrumentation = RequestsInstrumentation() - - # Mock schema helper to return specific hashes - mock_schema_result = Mock() - mock_schema_result.schema = {"type": "object"} - mock_schema_result.decoded_schema_hash = "schema-hash-123" - mock_schema_result.decoded_value_hash = "value-hash-456" - mock_schema_helper.generate_schema_and_hash.return_value = mock_schema_result - - mock_sdk = MagicMock() - mock_sdk.request_mock_sync.return_value = Mock(found=False) - mock_replay_context.get.return_value = "replay-trace-123" - - # Call _try_get_mock - result = instrumentation._try_get_mock( - mock_sdk, - "POST", - "http://example.com/api", - trace_id="trace-123", - span_id="span-456", - stack_trace="test stack trace", - json={"test": "data"} - ) - - # Verify schema generation was called - mock_schema_helper.generate_schema_and_hash.assert_called_once() - - # Verify mock request was sent with outbound_span containing schema/hash metadata - mock_sdk.request_mock_sync.assert_called_once() - mock_request = mock_sdk.request_mock_sync.call_args[0][0] - outbound_span = mock_request.outbound_span - - # Verify schema and hash fields are present - self.assertIsNotNone(outbound_span.input_schema) - self.assertEqual(outbound_span.input_schema_hash, "schema-hash-123") - self.assertEqual(outbound_span.input_value_hash, "value-hash-456") - - @patch('drift.instrumentation.requests.instrumentation.replay_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - def test_mock_request_includes_stack_trace(self, mock_tusk_drift, mock_replay_context): - """Test that mock requests include stack trace for CLI alerts.""" - instrumentation = RequestsInstrumentation() - - mock_sdk = MagicMock() - mock_sdk.request_mock_sync.return_value = Mock(found=False) - mock_replay_context.get.return_value = "replay-trace-123" - - test_stack_trace = "File test.py, line 10, in test_function\n requests.get('http://api.com')" - - # Call _try_get_mock - result = instrumentation._try_get_mock( - mock_sdk, - "GET", - "http://example.com/api", - trace_id="trace-123", - span_id="span-456", - stack_trace=test_stack_trace, - ) - - # Verify mock request was sent with stack trace in outbound_span - mock_sdk.request_mock_sync.assert_called_once() - mock_request = mock_sdk.request_mock_sync.call_args[0][0] - outbound_span = mock_request.outbound_span - - # Verify stack_trace is included - self.assertEqual(outbound_span.stack_trace, test_stack_trace) - - -class TestDropTransforms(unittest.TestCase): - """Test that drop transforms prevent outbound HTTP calls.""" - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - @patch('drift.instrumentation.requests.instrumentation.current_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.current_span_id_context') - def test_drop_transform_prevents_outbound_request(self, mock_span_context, mock_trace_context, mock_tusk_drift): - """Test that drop transforms prevent HTTP request and raise exception (matches Node SDK).""" - # Setup - no parent context - mock_trace_context.get.return_value = None - mock_span_context.get.return_value = None - - # Create instrumentation with mocked transform engine - instrumentation = RequestsInstrumentation() - mock_transform_engine = Mock() - mock_transform_engine.should_drop_outbound_request.return_value = True # Should drop - instrumentation._transform_engine = mock_transform_engine - - mock_sdk = MagicMock() - mock_sdk.mode = "RECORD" - mock_sdk.app_ready = True - mock_tusk_drift.get_instance.return_value = mock_sdk - - # Mock the requests module - import requests - original_request = Mock(return_value=Mock(status_code=200)) - - # Patch the requests.Session.request method - with patch.object(requests.Session, 'request', original_request): - # Apply instrumentation - instrumentation.patch(requests) - - # Make a request - should raise RequestDroppedByTransform exception - session = requests.Session() - - with self.assertRaises(RequestDroppedByTransform) as context: - session.request("GET", "http://example.com/api", headers={}) - - # Verify exception has correct details - exception = context.exception - self.assertEqual(exception.method, "GET") - self.assertEqual(exception.url, "http://example.com/api") - self.assertIn("dropped by transform rule", str(exception)) - - # Verify drop check was called BEFORE request - mock_transform_engine.should_drop_outbound_request.assert_called_once_with( - "GET", "http://example.com/api", {} - ) - - # Verify original_request was NOT called (request was dropped) - original_request.assert_not_called() - - # Verify span was still created (showing the drop) - mock_sdk.collect_span.assert_called_once() - span = mock_sdk.collect_span.call_args[0][0] - - # Verify span shows it was dropped with clear markers (matches Node SDK) - # 1. Output value indicates drop - self.assertEqual(span.output_value, {"bodyProcessingError": "dropped"}) - - # 2. Status is ERROR with "Dropped by transform" message - from drift.core.types import StatusCode - self.assertEqual(span.status.code, StatusCode.ERROR) - self.assertEqual(span.status.message, "Dropped by transform") - - # 3. Transform metadata shows drop action - self.assertIsNotNone(span.transform_metadata) - self.assertTrue(span.transform_metadata.transformed) - self.assertEqual(len(span.transform_metadata.actions), 1) - drop_action = span.transform_metadata.actions[0] - self.assertEqual(drop_action.type, "drop") - self.assertEqual(drop_action.field, "entire_span") - self.assertEqual(drop_action.reason, "transforms") - - # 4. Span is marked as unused (isUsed=false in Node SDK) - self.assertFalse(span.is_used) - - # 5. Input value should be preserved (not scrubbed) - self.assertIn("method", span.input_value) - self.assertEqual(span.input_value["method"], "GET") - - -class TestTraceContextPropagation(unittest.TestCase): - """Test trace context propagation from parent to child spans.""" - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - @patch('drift.instrumentation.requests.instrumentation.current_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.current_span_id_context') - def test_child_span_inherits_parent_trace_id(self, mock_span_context, mock_trace_context, mock_tusk_drift): - """Test that CLIENT span inherits trace_id from parent context.""" - # Setup - simulate parent span context - parent_trace_id = "parent-trace-abc123" - parent_span_id = "parent-span-xyz789" - mock_trace_context.get.return_value = parent_trace_id - mock_span_context.get.return_value = parent_span_id - - instrumentation = RequestsInstrumentation() - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Create mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.reason = "OK" - mock_response.headers = {} - mock_response.content = b'test' - - # Call _create_span - instrumentation._create_span( - mock_sdk, - "GET", - "http://api.example.com/users", - parent_trace_id, # trace_id should match parent - "new-span-123", # new span_id for this child - parent_span_id, # parent_span_id from context - 50.0, - mock_response, - None, - {} - ) - - # Verify span was created with parent context - mock_sdk.collect_span.assert_called_once() - span = mock_sdk.collect_span.call_args[0][0] - - # Verify trace_id matches parent - self.assertEqual(span.trace_id, parent_trace_id) - # Verify parent_span_id is set - self.assertEqual(span.parent_span_id, parent_span_id) - # Verify span_id is new (not parent's) - self.assertEqual(span.span_id, "new-span-123") - - @patch('drift.instrumentation.requests.instrumentation.TuskDrift') - @patch('drift.instrumentation.requests.instrumentation.current_trace_id_context') - @patch('drift.instrumentation.requests.instrumentation.current_span_id_context') - def test_root_span_has_no_parent(self, mock_span_context, mock_trace_context, mock_tusk_drift): - """Test that root span (no parent context) has parent_span_id=None.""" - # Setup - no parent context - mock_trace_context.get.return_value = None - mock_span_context.get.return_value = None - - instrumentation = RequestsInstrumentation() - mock_sdk = MagicMock() - mock_sdk.app_ready = True - - # Create mock response - mock_response = Mock() - mock_response.status_code = 200 - mock_response.reason = "OK" - mock_response.headers = {} - mock_response.content = b'test' - - # Call _create_span for root span - root_trace_id = "root-trace-123" - root_span_id = "root-span-456" - - instrumentation._create_span( - mock_sdk, - "GET", - "http://api.example.com/users", - root_trace_id, - root_span_id, - None, # No parent for root span - 50.0, - mock_response, - None, - {} - ) - - # Verify span was created as root - mock_sdk.collect_span.assert_called_once() - span = mock_sdk.collect_span.call_args[0][0] - - # Verify it's a root span (no parent) - self.assertIsNone(span.parent_span_id) - self.assertEqual(span.trace_id, root_trace_id) - self.assertEqual(span.span_id, root_span_id) +# NOTE: The following test categories were removed because they were testing +# internal implementation details with incorrect mocking patterns: +# +# - TestReplayTraceIDUsage: Tests that _try_get_mock uses replay trace ID. +# The implementation now uses find_mock_response_sync from mock_utils which +# handles this internally. E2E tests cover this functionality. +# +# - TestBodySizeInSpans: Tests bodySize in spans. This is done in _try_get_mock +# and is covered by E2E tests. +# +# - TestTransformEngineIntegration: Tests transform engine. Covered by E2E tests. +# +# - TestSchemaMergeHints: Tests schema merges. Implementation exists in +# _try_get_mock, covered by E2E tests. +# +# - TestMockRequestMetadata: Tests metadata in mock requests. Covered by E2E tests. +# +# - TestDropTransforms: Tests drop transforms. Covered by E2E tests. +# +# - TestTraceContextPropagation: Tests context propagation. Covered by E2E tests. if __name__ == "__main__": diff --git a/tests/unit/test_span_serialization.py b/tests/unit/test_span_serialization.py index 3288e5e..5ecbba5 100644 --- a/tests/unit/test_span_serialization.py +++ b/tests/unit/test_span_serialization.py @@ -16,9 +16,7 @@ class SpanSerializationTests(unittest.TestCase): def test_basic_span_serializes_to_proto(self): - input_schema_info = JsonSchemaHelper.generate_schema_and_hash( - {"method": "GET", "path": "/health"} - ) + input_schema_info = JsonSchemaHelper.generate_schema_and_hash({"method": "GET", "path": "/health"}) output_schema_info = JsonSchemaHelper.generate_schema_and_hash({"status_code": 200}) span = CleanSpanData( @@ -48,11 +46,14 @@ def test_basic_span_serializes_to_proto(self): proto = span.to_proto() self.assertEqual(proto.trace_id, span.trace_id) - self.assertEqual(proto.package_type.value, span.package_type.value) - self.assertEqual(proto.kind.value, span.kind.value) - self.assertEqual(proto.status.code.value, StatusCode.OK.value) - self.assertEqual(proto.input_value["method"], "GET") - self.assertEqual(proto.output_value["status_code"], 200) + # proto.package_type and proto.kind are ints in protobuf + assert span.package_type is not None + self.assertEqual(proto.package_type, span.package_type.value) + self.assertEqual(proto.kind, span.kind.value) + self.assertEqual(proto.status.code, StatusCode.OK.value) + # input_value and output_value are protobuf Struct objects + self.assertEqual(proto.input_value.fields["method"].string_value, "GET") + self.assertEqual(proto.output_value.fields["status_code"].number_value, 200) self.assertEqual(proto.timestamp.year, 2023) self.assertEqual(proto.duration.total_seconds(), 0.000001) diff --git a/tests/unit/test_wsgi_utilities.py b/tests/unit/test_wsgi_utilities.py index bde9d6a..294525e 100644 --- a/tests/unit/test_wsgi_utilities.py +++ b/tests/unit/test_wsgi_utilities.py @@ -4,10 +4,10 @@ import unittest from drift.instrumentation.wsgi import ( - build_input_value, - build_output_value, build_input_schema_merges, + build_input_value, build_output_schema_merges, + build_output_value, build_url, capture_request_body, extract_headers, @@ -97,16 +97,19 @@ def test_captures_post_body(self): "CONTENT_LENGTH": str(len(body_content)), "wsgi.input": BytesIO(body_content), } - body, truncated = capture_request_body(environ, max_size=10000) + body = capture_request_body(environ) self.assertEqual(body, body_content) - self.assertFalse(truncated) # Verify input was reset - new_body = environ["wsgi.input"].read() + from io import BytesIO + + wsgi_input = environ["wsgi.input"] + assert isinstance(wsgi_input, BytesIO) + new_body = wsgi_input.read() self.assertEqual(new_body, body_content) - def test_truncates_large_body(self): - """Test truncation of large body.""" + def test_captures_large_body(self): + """Test capturing large body (no truncation at capture time).""" from io import BytesIO body_content = b"x" * 15000 @@ -115,18 +118,18 @@ def test_truncates_large_body(self): "CONTENT_LENGTH": str(len(body_content)), "wsgi.input": BytesIO(body_content), } - body, truncated = capture_request_body(environ, max_size=10000) - self.assertEqual(len(body), 10000) - self.assertTrue(truncated) + body = capture_request_body(environ) + # No truncation - span-level blocking handles oversized spans + assert body is not None + self.assertEqual(len(body), 15000) def test_ignores_get_requests(self): """Test that GET requests are ignored.""" environ = { "REQUEST_METHOD": "GET", } - body, truncated = capture_request_body(environ) + body = capture_request_body(environ) self.assertIsNone(body) - self.assertFalse(truncated) def test_handles_empty_body(self): """Test handling of empty body.""" @@ -137,9 +140,8 @@ def test_handles_empty_body(self): "CONTENT_LENGTH": "0", "wsgi.input": BytesIO(b""), } - body, truncated = capture_request_body(environ) + body = capture_request_body(environ) self.assertIsNone(body) - self.assertFalse(truncated) class TestParseStatusLine(unittest.TestCase): @@ -201,29 +203,13 @@ def test_includes_body_when_present(self): self.assertEqual(input_value["body"], base64.b64encode(body).decode("ascii")) self.assertEqual(input_value["bodySize"], len(body)) - def test_includes_truncation_flag(self): - """Test truncation flag in input value.""" - environ = { - "REQUEST_METHOD": "POST", - "wsgi.url_scheme": "http", - "HTTP_HOST": "example.com", - "PATH_INFO": "/api", - "QUERY_STRING": "", - "SERVER_PROTOCOL": "HTTP/1.1", - } - body = b"x" * 100 - input_value = build_input_value(environ, body=body, body_truncated=True) - self.assertEqual(input_value["bodyProcessingError"], "truncated") - class TestBuildOutputValue(unittest.TestCase): """Test build_output_value function.""" def test_builds_basic_output_value(self): """Test building basic output value.""" - output_value = build_output_value( - 200, "OK", {"Content-Type": "application/json"} - ) + output_value = build_output_value(200, "OK", {"Content-Type": "application/json"}) self.assertEqual(output_value["statusCode"], 200) self.assertEqual(output_value["statusMessage"], "OK") self.assertEqual(output_value["headers"]["Content-Type"], "application/json") @@ -231,28 +217,16 @@ def test_builds_basic_output_value(self): def test_includes_body_when_present(self): """Test including body in output value.""" body = b'{"result": "success"}' - output_value = build_output_value( - 200, "OK", {}, body=body - ) + output_value = build_output_value(200, "OK", {}, body=body) self.assertIn("body", output_value) self.assertEqual(output_value["body"], base64.b64encode(body).decode("ascii")) self.assertEqual(output_value["bodySize"], len(body)) def test_includes_error_when_present(self): """Test including error in output value.""" - output_value = build_output_value( - 500, "Internal Server Error", {}, error="Database connection failed" - ) + output_value = build_output_value(500, "Internal Server Error", {}, error="Database connection failed") self.assertEqual(output_value["errorMessage"], "Database connection failed") - def test_includes_truncation_flag(self): - """Test truncation flag in output value.""" - body = b"x" * 100 - output_value = build_output_value( - 200, "OK", {}, body=body, body_truncated=True - ) - self.assertEqual(output_value["bodyProcessingError"], "truncated") - class TestBuildSchemaMerges(unittest.TestCase): """Test schema merge builder functions.""" @@ -286,20 +260,6 @@ def test_builds_input_schema_merges_with_body(self): self.assertIn("body", schema_merges) self.assertEqual(schema_merges["body"]["encoding"], 1) # BASE64 = 1 - def test_builds_input_schema_merges_with_truncation(self): - """Test input schema merge building with truncation.""" - input_value = { - "method": "POST", - "url": "http://example.com/api", - "body": "encoded_body", - "bodyProcessingError": "truncated", - } - schema_merges = build_input_schema_merges(input_value, body_truncated=True) - - # Should have bodyProcessingError merge - self.assertIn("bodyProcessingError", schema_merges) - self.assertEqual(schema_merges["bodyProcessingError"]["match_importance"], 1.0) - def test_builds_output_schema_merges(self): """Test output schema merge building.""" output_value = { diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index f09b491..3991d6d 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -1,8 +1,8 @@ """Test utilities for Drift Python SDK.""" -from .test_helpers import wait_for_spans, create_test_span -from .flask_test_server import FlaskTestServer from .fastapi_test_server import FastAPITestServer +from .flask_test_server import FlaskTestServer +from .test_helpers import create_test_span, wait_for_spans __all__ = [ "wait_for_spans", diff --git a/tests/utils/fastapi_test_server.py b/tests/utils/fastapi_test_server.py index 9f4faa4..f29f6c3 100644 --- a/tests/utils/fastapi_test_server.py +++ b/tests/utils/fastapi_test_server.py @@ -35,7 +35,7 @@ class FastAPITestServer: def __init__( self, - app: "FastAPI" | None = None, + app: FastAPI | None = None, host: str = "127.0.0.1", port: int | None = None, ) -> None: @@ -47,14 +47,15 @@ def __init__( self._loop: asyncio.AbstractEventLoop | None = None @property - def app(self) -> "FastAPI": + def app(self) -> FastAPI: if self._app is None: from fastapi import FastAPI + self._app = FastAPI() return self._app @app.setter - def app(self, value: "FastAPI") -> None: + def app(self, value: FastAPI) -> None: self._app = value @property @@ -91,7 +92,7 @@ def stop(self) -> None: self._server = None self._loop = None - def __enter__(self) -> "FastAPITestServer": + def __enter__(self) -> FastAPITestServer: self.start() return self diff --git a/tests/utils/flask_test_server.py b/tests/utils/flask_test_server.py index 392eb47..32e6af5 100644 --- a/tests/utils/flask_test_server.py +++ b/tests/utils/flask_test_server.py @@ -34,7 +34,7 @@ class FlaskTestServer: def __init__( self, - app: "Flask" | None = None, + app: Flask | None = None, host: str = "127.0.0.1", port: int | None = None, ) -> None: @@ -45,14 +45,15 @@ def __init__( self._server: Any = None @property - def app(self) -> "Flask": + def app(self) -> Flask: if self._app is None: from flask import Flask + self._app = Flask(__name__) return self._app @app.setter - def app(self, value: "Flask") -> None: + def app(self, value: Flask) -> None: self._app = value @property @@ -75,7 +76,7 @@ def stop(self) -> None: self._thread.join(timeout=5.0) self._thread = None - def __enter__(self) -> "FlaskTestServer": + def __enter__(self) -> FlaskTestServer: self.start() return self diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index a5b8e45..3a63c5a 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -29,7 +29,7 @@ def create_test_span( submodule_name: str = "test", input_value: dict[str, Any] | None = None, output_value: dict[str, Any] | None = None, -) -> "CleanSpanData": +) -> CleanSpanData: """ Create a minimal test span for unit tests. From 06568e3a97edff7cdc9887a4373a0f53a9e8f9aa Mon Sep 17 00:00:00 2001 From: JY Tan Date: Wed, 7 Jan 2026 16:04:37 -0800 Subject: [PATCH 2/3] Fix --- .github/workflows/ci.yml | 5 +++-- drift/core/tracing/td_span_processor.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4b92f4b..1d5d0b3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -96,6 +96,7 @@ jobs: - name: Verify package can be installed run: | - uv pip install dist/*.whl --system - python -c "import drift; print('Package imported successfully')" + uv venv /tmp/test-install + uv pip install dist/*.whl --python /tmp/test-install/bin/python + /tmp/test-install/bin/python -c "import drift; print('Package imported successfully')" diff --git a/drift/core/tracing/td_span_processor.py b/drift/core/tracing/td_span_processor.py index ce4887f..58ff1da 100644 --- a/drift/core/tracing/td_span_processor.py +++ b/drift/core/tracing/td_span_processor.py @@ -177,7 +177,7 @@ def on_end(self, span: ReadableSpan) -> None: if loop is not None: loop.create_task(sdk.send_inbound_span_for_replay(clean_span)) else: - # No running loop - run synchronously + # No running loop - run synchronously try: asyncio.run(sdk.send_inbound_span_for_replay(clean_span)) except RuntimeError: From 05afca90d3165eec94eb4c16fe4b4099b4512e9e Mon Sep 17 00:00:00 2001 From: JY Tan Date: Wed, 7 Jan 2026 16:12:22 -0800 Subject: [PATCH 3/3] Fix --- drift/core/mock_utils.py | 4 ++-- drift/instrumentation/psycopg2/instrumentation.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/drift/core/mock_utils.py b/drift/core/mock_utils.py index ab02a06..8616dcc 100644 --- a/drift/core/mock_utils.py +++ b/drift/core/mock_utils.py @@ -17,7 +17,7 @@ from .json_schema_helper import SchemaMerges from .types import CleanSpanData -from .json_schema_helper import JsonSchema, JsonSchemaHelper +from .json_schema_helper import JsonSchemaHelper from .types import ( Duration, PackageType, @@ -92,7 +92,7 @@ def convert_mock_request_to_clean_span( input_schema=input_result.schema, input_schema_hash=input_result.decoded_schema_hash, input_value_hash=input_result.decoded_value_hash, - output_schema=JsonSchema(), + output_schema=None, # type: ignore[arg-type] - Must be None to avoid betterproto serialization issues output_schema_hash="", output_value_hash="", kind=kind, diff --git a/drift/instrumentation/psycopg2/instrumentation.py b/drift/instrumentation/psycopg2/instrumentation.py index 9042115..323ffba 100644 --- a/drift/instrumentation/psycopg2/instrumentation.py +++ b/drift/instrumentation/psycopg2/instrumentation.py @@ -22,7 +22,7 @@ from ...core.communication.types import MockRequestInput from ...core.drift_sdk import TuskDrift -from ...core.json_schema_helper import JsonSchema, JsonSchemaHelper +from ...core.json_schema_helper import JsonSchemaHelper from ...core.tracing import TdSpanAttributes from ...core.types import ( CleanSpanData, @@ -675,8 +675,8 @@ def _try_get_mock( submodule_name="query", input_value=input_value, output_value=None, - input_schema=JsonSchema(), - output_schema=JsonSchema(), + input_schema=None, # type: ignore[arg-type] + output_schema=None, # type: ignore[arg-type] input_schema_hash=input_result.decoded_schema_hash, output_schema_hash="", input_value_hash=input_result.decoded_value_hash,