Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
name: CI

on:
push:
branches: [main]
pull_request:
branches: [main]

jobs:
lint:
name: Lint & Format
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"

- name: Setup Python
run: uv python install 3.12

- name: Install dependencies
run: uv sync --all-extras

- name: Check formatting
run: uv run ruff format --check drift/ tests/

- name: Lint
run: uv run ruff check drift/ tests/

typecheck:
name: Type Check
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"

- name: Setup Python
run: uv python install 3.12

- name: Install dependencies
run: uv sync --all-extras

- name: Type check
run: uv run ty check drift/ tests/

test:
name: Unit Tests
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"

- name: Setup Python
run: uv python install 3.12

- name: Install dependencies
run: uv sync --all-extras

- name: Run unit tests
run: uv run pytest tests/unit/ -v

build:
name: Build Package
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"

- name: Setup Python
run: uv python install 3.12

- name: Install dependencies
run: uv sync --all-extras

- name: Build package
run: uv build

- name: Verify package can be installed
run: |
uv venv /tmp/test-install
uv pip install dist/*.whl --python /tmp/test-install/bin/python
/tmp/test-install/bin/python -c "import drift; print('Package imported successfully')"

6 changes: 3 additions & 3 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ uv sync --all-extras
This project uses [ruff](https://docs.astral.sh/ruff/) for linting/formatting and [ty](https://github.com/astral-sh/ty) for type checking.

```bash
uv run ruff check drift/ --fix # Lint and auto-fix
uv run ruff format drift/ # Format
uv run ty check drift/ # Type check
uv run ruff check drift/ tests/ --fix # Lint and auto-fix
uv run ruff format drift/ tests/ # Format
uv run ty check drift/ tests/ # Type check
```

## Running Tests
Expand Down
10 changes: 7 additions & 3 deletions drift/core/communication/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from typing import Any

from tusk.drift.core.v1 import GetMockRequest as ProtoGetMockRequest
from ..span_serialization import clean_span_to_proto

from ...version import MIN_CLI_VERSION, SDK_VERSION
from ..span_serialization import clean_span_to_proto
from ..types import CleanSpanData, calling_library_context
from .types import (
CliMessage,
Expand Down Expand Up @@ -566,7 +567,7 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput:
response = cli_message.connect_response
if response.success:
logger.debug("CLI acknowledged connection")
self._session_id = response.session_id
# Note: session_id is not in the protobuf schema
else:
logger.error(f"CLI rejected connection: {response.error}")
continue
Expand All @@ -576,6 +577,8 @@ async def _receive_response(self, request_id: str) -> MockResponseOutput:

def _recv_exact(self, n: int) -> bytes | None:
"""Receive exactly n bytes from socket."""
if self._socket is None:
return None
data = bytearray()
while len(data) < n:
chunk = self._socket.recv(n - len(data))
Expand Down Expand Up @@ -655,7 +658,8 @@ def _handle_cli_message(self, message: CliMessage) -> MockResponseOutput:
if message.connect_response:
response = message.connect_response
if response.success:
self._session_id = response.session_id
logger.debug("CLI acknowledged connection")
# Note: session_id is not in the protobuf schema
return MockResponseOutput(found=False, error="Unexpected connect response")

return MockResponseOutput(found=False, error="Unknown message type")
Expand Down
65 changes: 56 additions & 9 deletions drift/core/communication/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,43 @@
CLIMessageType = MessageType


def _python_to_value(value: Any) -> Any:
"""Convert Python value to protobuf Value."""
from betterproto.lib.google.protobuf import ListValue, Value

if value is None:
from betterproto.lib.google.protobuf import NullValue

return Value(null_value=NullValue.NULL_VALUE) # type: ignore[arg-type]
elif isinstance(value, bool):
return Value(bool_value=value)
elif isinstance(value, (int, float)):
return Value(number_value=float(value))
elif isinstance(value, str):
return Value(string_value=value)
elif isinstance(value, dict):
from betterproto.lib.google.protobuf import Struct

struct = Struct()
struct.fields = {k: _python_to_value(v) for k, v in value.items()}
return Value(struct_value=struct)
elif isinstance(value, (list, tuple)):
list_val = ListValue(values=[_python_to_value(item) for item in value])
return Value(list_value=list_val)
else:
return Value(string_value=str(value))


def _dict_to_struct(data: dict[str, Any]) -> Any:
"""Convert Python dict to protobuf Struct."""
from betterproto.lib.google.protobuf import Struct

struct = Struct()
if data:
struct.fields = {k: _python_to_value(v) for k, v in data.items()}
return struct


@dataclass
class ConnectRequest:
"""Initial connection request from SDK to CLI.
Expand All @@ -102,11 +139,18 @@ class ConnectRequest:

def to_proto(self) -> ProtoConnectRequest:
"""Convert to protobuf message."""
from betterproto.lib.google.protobuf import Struct

# Convert metadata dict to protobuf Struct
metadata_struct = Struct()
if self.metadata:
metadata_struct.fields = {k: _python_to_value(v) for k, v in self.metadata.items()}

return ProtoConnectRequest(
service_id=self.service_id,
sdk_version=self.sdk_version,
min_cli_version=self.min_cli_version,
metadata=self.metadata,
metadata=metadata_struct,
)


Expand All @@ -129,10 +173,12 @@ class ConnectResponse:
@classmethod
def from_proto(cls, proto: ProtoConnectResponse) -> ConnectResponse:
"""Create from protobuf message."""
# Note: ProtoConnectResponse only has success and error fields
# cli_version and session_id are SDK-only extensions not in the protobuf schema
return cls(
success=proto.success,
cli_version=proto.cli_version or None,
session_id=proto.session_id or None,
cli_version=None,
session_id=None,
error=proto.error or None,
)

Expand Down Expand Up @@ -164,7 +210,7 @@ class GetMockRequest:

def to_proto(self) -> ProtoGetMockRequest:
"""Convert to protobuf message."""
span = dict_to_span(self.outbound_span) if self.outbound_span else None
span = dict_to_span(self.outbound_span) if self.outbound_span else ProtoSpan()
return ProtoGetMockRequest(
request_id=self.request_id,
test_id=self.test_id,
Expand Down Expand Up @@ -253,7 +299,8 @@ def dict_to_span(data: dict[str, Any]) -> ProtoSpan:
message=status_data.get("message", ""),
)
else:
status = ProtoSpanStatus(code=ProtoStatusCode.UNSET)
# UNSET is not a valid StatusCode in the proto - use UNSPECIFIED (0)
status = ProtoSpanStatus(code=ProtoStatusCode.UNSPECIFIED)

return ProtoSpan(
trace_id=data.get("trace_id", data.get("traceId", "")),
Expand All @@ -263,8 +310,8 @@ def dict_to_span(data: dict[str, Any]) -> ProtoSpan:
package_name=data.get("package_name", data.get("packageName", "")),
instrumentation_name=data.get("instrumentation_name", data.get("instrumentationName", "")),
submodule_name=data.get("submodule_name", data.get("submoduleName", "")),
input_value=data.get("input_value", data.get("inputValue", {})),
output_value=data.get("output_value", data.get("outputValue", {})),
input_value=_dict_to_struct(data.get("input_value", data.get("inputValue", {}))),
output_value=_dict_to_struct(data.get("output_value", data.get("outputValue", {}))),
input_schema_hash=data.get("input_schema_hash", data.get("inputSchemaHash", "")),
output_schema_hash=data.get("output_schema_hash", data.get("outputSchemaHash", "")),
input_value_hash=data.get("input_value_hash", data.get("inputValueHash", "")),
Expand Down Expand Up @@ -306,8 +353,8 @@ def span_to_proto(span: Any) -> ProtoSpan:
package_name=span.package_name or "",
instrumentation_name=span.instrumentation_name or "",
submodule_name=span.submodule_name or "",
input_value=span.input_value or {},
output_value=span.output_value or {},
input_value=_dict_to_struct(span.input_value or {}),
output_value=_dict_to_struct(span.output_value or {}),
input_schema_hash=span.input_schema_hash or "",
output_schema_hash=span.output_schema_hash or "",
input_value_hash=span.input_value_hash or "",
Expand Down
5 changes: 3 additions & 2 deletions drift/core/mock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .communication.types import MockResponseOutput
from .drift_sdk import TuskDrift
from .json_schema_helper import SchemaMerges
from .types import CleanSpanData, MockResponseOutput
from .types import CleanSpanData

from .json_schema_helper import JsonSchemaHelper
from .types import (
Expand Down Expand Up @@ -91,7 +92,7 @@ def convert_mock_request_to_clean_span(
input_schema=input_result.schema,
input_schema_hash=input_result.decoded_schema_hash,
input_value_hash=input_result.decoded_value_hash,
output_schema=None,
output_schema=None, # type: ignore[arg-type] - Must be None to avoid betterproto serialization issues
output_schema_hash="",
output_value_hash="",
kind=kind,
Expand Down
10 changes: 6 additions & 4 deletions drift/core/span_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

from .json_schema_helper import DecodedType, EncodingType, JsonSchema, JsonSchemaType
from .types import CleanSpanData
from .types import CleanSpanData, PackageType


def _value_to_proto(value: Any) -> ProtoValue:
Expand All @@ -37,7 +37,9 @@ def _value_to_proto(value: Any) -> ProtoValue:
proto_value = ProtoValue()

if value is None:
proto_value.null_value = 0
from betterproto.lib.google.protobuf import NullValue

proto_value.null_value = NullValue.NULL_VALUE # type: ignore[assignment]
elif isinstance(value, bool):
proto_value.bool_value = value
elif isinstance(value, (int, float)):
Expand Down Expand Up @@ -94,7 +96,7 @@ def clean_span_to_proto(span: CleanSpanData) -> ProtoSpan:
package_name=span.package_name,
instrumentation_name=span.instrumentation_name,
submodule_name=span.submodule_name,
package_type=span.package_type.value if span.package_type else 0,
package_type=span.package_type.value if span.package_type else PackageType.UNSPECIFIED.value, # type: ignore[arg-type]
environment=span.environment,
kind=span.kind.value if hasattr(span.kind, "value") else span.kind,
input_value=_dict_to_struct(span.input_value),
Expand All @@ -119,7 +121,7 @@ def clean_span_to_proto(span: CleanSpanData) -> ProtoSpan:
seconds=span.duration.seconds,
microseconds=span.duration.nanos // 1000,
),
metadata=_metadata_to_dict(span.metadata),
metadata=_dict_to_struct(_metadata_to_dict(span.metadata)),
)


Expand Down
1 change: 1 addition & 0 deletions drift/core/trace_blocking_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def get_instance(cls) -> TraceBlockingManager:
if cls._instance is None:
cls._instance = TraceBlockingManager()
cls._instance._start_cleanup_thread()
assert cls._instance is not None
return cls._instance

def block_trace(self, trace_id: str, reason: str = "size_limit") -> None:
Expand Down
Loading