Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
rev: v2.4.1
hooks:
- id: codespell
entry: codespell -q 3 -f --skip=".git,.github,README.md" --ignore-words-list="astroid"
entry: codespell -q 3 -f --skip=".git,.github,README.md" --ignore-words-list="astroid,braket,te"

# Python code security
- repo: https://github.com/PyCQA/bandit
Expand All @@ -57,21 +57,21 @@ repos:
- id: nbstripout
args: [--keep-id]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.7
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
# - repo: https://github.com/astral-sh/ruff-pre-commit
# rev: v0.11.7
# hooks:
# - id: ruff
# args: [--fix]
# - id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
- "types-PyYAML"
- "types-requests"
- "types-setuptools"
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: v1.15.0
# hooks:
# - id: mypy
# additional_dependencies:
# - "types-PyYAML"
# - "types-requests"
# - "types-setuptools"

- repo: local
hooks:
Expand Down
2 changes: 2 additions & 0 deletions dreadnode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
task_span = DEFAULT_INSTANCE.task_span
push_update = DEFAULT_INSTANCE.push_update
tag = DEFAULT_INSTANCE.tag
get_run_context = DEFAULT_INSTANCE.get_run_context
continue_run = DEFAULT_INSTANCE.continue_run

log_metric = DEFAULT_INSTANCE.log_metric
log_param = DEFAULT_INSTANCE.log_param
Expand Down
4 changes: 2 additions & 2 deletions dreadnode/data_types/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _process_audio_data(self) -> tuple[bytes, str, int | None, float | None]:
Returns:
A tuple of (audio_bytes, format_name, sample_rate, duration)
"""
if isinstance(self._data, (str, Path)) and Path(self._data).exists():
if isinstance(self._data, str | Path) and Path(self._data).exists():
return self._process_file_path()
if isinstance(self._data, np.ndarray):
return self._process_numpy_array()
Expand Down Expand Up @@ -159,7 +159,7 @@ def _generate_metadata(
"x-python-datatype": "dreadnode.Audio.bytes",
}

if isinstance(self._data, (str, Path)):
if isinstance(self._data, str | Path):
metadata["source-type"] = "file"
metadata["source-path"] = str(self._data)
elif isinstance(self._data, np.ndarray):
Expand Down
Empty file removed dreadnode/data_types/py.typed
Empty file.
59 changes: 56 additions & 3 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from logfire._internal.exporters.remove_pending import RemovePendingSpansExporter
from logfire._internal.stack_info import get_filepath_attribute, warn_at_user_stacklevel
from logfire._internal.utils import safe_repr
from opentelemetry import propagate
from opentelemetry.exporter.otlp.proto.http import Compression
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor
Expand All @@ -39,6 +40,7 @@
FileSpanExporter,
)
from dreadnode.tracing.span import (
RunContext,
RunSpan,
Span,
TaskSpan,
Expand Down Expand Up @@ -154,7 +156,7 @@ def configure(
server: The Dreadnode server URL.
token: The Dreadnode API token.
local_dir: The local directory to store data in.
project: The defautlt project name to associate all runs with.
project: The default project name to associate all runs with.
service_name: The service name to use for OpenTelemetry.
service_version: The service version to use for OpenTelemetry.
console: Whether to log span information to the console.
Expand Down Expand Up @@ -198,7 +200,7 @@ def initialize(self) -> None:
metric_readers: list[MetricReader] = []

self.server = self.server or (DEFAULT_SERVER_URL if self.token else None)
if not (self.server and self.token and self.local_dir):
if not (self.server or self.token or self.local_dir):
warn_at_user_stacklevel(
"Your current configuration won't persist run data anywhere. "
"Use `dreadnode.init(server=..., token=...)`, `dreadnode.init(local_dir=...)`, "
Expand Down Expand Up @@ -280,6 +282,7 @@ def initialize(self) -> None:
console=logfire.ConsoleOptions() if self.console is True else self.console,
scrubbing=False,
inspect_arguments=False,
distributed_tracing=False,
)
self._logfire.config.ignore_no_config = True

Expand Down Expand Up @@ -660,12 +663,16 @@ def run(
run will be associated with a default project.
autolog: Whether to automatically log task inputs, outputs, and execution metrics if unspecified.
**attributes: Additional attributes to attach to the run span.

Returns:
A RunSpan object that can be used as a context manager.
The run will automatically be completed when the context manager exits.
"""
if not self._initialized:
self.initialize()

if name is None:
name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311
name = f"{coolname.generate_slug(2)}-{random.randint(100, 999)}" # noqa: S311 # nosec

return RunSpan(
name=name,
Expand All @@ -679,6 +686,52 @@ def run(
autolog=autolog,
)

def get_run_context(self) -> RunContext:
    """
    Snapshot the active run so it can be resumed somewhere else.

    The snapshot holds the run's id, name, and project together with the
    OpenTelemetry propagation headers injected for the current context.
    Pass the result to `continue_run()` on another host, thread, or
    process to keep working inside the same run.

    Returns:
        A RunContext describing the active run and its trace context.

    Raises:
        RuntimeError: If there is no active run.
    """
    active_run = current_run_span.get()
    if active_run is None:
        raise RuntimeError("get_run_context() must be called within a run")

    # Have OpenTelemetry write its propagation headers into a fresh carrier
    carrier: dict[str, str] = {}
    propagate.inject(carrier)

    return RunContext(
        run_id=active_run.run_id,
        run_name=active_run.name,
        project=active_run.project,
        trace_context=carrier,
    )

def continue_run(self, run_context: RunContext) -> RunSpan:
    """
    Resume a run described by a previously captured context.

    Args:
        run_context: The value returned by `get_run_context()`.

    Returns:
        A RunSpan built with `RunSpan.from_context()` that can be used
        as a context manager.
    """
    # Initialize lazily, the same way `run()` does
    if not self._initialized:
        self.initialize()

    tracer = self._get_tracer()
    return RunSpan.from_context(
        context=run_context,
        tracer=tracer,
        file_system=self._fs,
        prefix_path=self._fs_prefix,
    )

def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:
"""
Add one or many tags to the current task or run.
Expand Down
4 changes: 2 additions & 2 deletions dreadnode/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
Returns:
The span associated with task execution.
"""
run = current_run_span.get()
if run is None or not run.is_recording:

if (run := current_run_span.get()) is None:
raise RuntimeError("Tasks must be executed within a run")

log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs
Expand Down
2 changes: 1 addition & 1 deletion dreadnode/tracing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

SPAN_NAMESPACE = "dreadnode"

SpanType = t.Literal["run", "task", "span", "run_update"]
SpanType = t.Literal["run", "task", "span", "run_update", "run_fragment"]

SPAN_ATTRIBUTE_VERSION = f"{SPAN_NAMESPACE}.version"
SPAN_ATTRIBUTE_TYPE = f"{SPAN_NAMESPACE}.type"
Expand Down
59 changes: 54 additions & 5 deletions dreadnode/tracing/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from logfire._internal.tracer import OPEN_SPANS
from logfire._internal.utils import uniquify_sequence
from opentelemetry import context as context_api
from opentelemetry import propagate
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.trace import Tracer
Expand Down Expand Up @@ -225,6 +226,15 @@ def log_event(
)


class RunContext(te.TypedDict):
    """Snapshot of a run, used to transfer it and continue it elsewhere."""

    # Identifier of the run being continued
    run_id: str
    # Name of the run the context was captured from
    run_name: str
    # Project the run is associated with
    project: str
    # Carrier of trace propagation headers (written by `propagate.inject`)
    trace_context: dict[str, str]


class RunUpdateSpan(Span):
def __init__(
self,
Expand Down Expand Up @@ -274,10 +284,11 @@ def __init__(
*,
params: AnyDict | None = None,
metrics: MetricDict | None = None,
run_id: str | None = None,
tags: t.Sequence[str] | None = None,
autolog: bool = True,
update_frequency: int = 5,
run_id: str | ULID | None = None,
type: SpanType = "run",
) -> None:
self.autolog = autolog
self.project = project
Expand Down Expand Up @@ -307,6 +318,8 @@ def __init__(
self._pending_object_schemas = deepcopy(self._object_schemas)

self._context_token: Token[RunSpan | None] | None = None # contextvars context
self._remote_context: dict[str, str] | None = None # remote run trace context
self._remote_token: object | None = None
self._file_system = file_system
self._prefix_path = prefix_path

Expand All @@ -315,23 +328,55 @@ def __init__(
SPAN_ATTRIBUTE_PROJECT: project,
**attributes,
}
super().__init__(name, attributes, tracer, type="run", tags=tags)
super().__init__(name, attributes, tracer, type=type, tags=tags)

@classmethod
def from_context(
    cls,
    context: RunContext,
    tracer: Tracer,
    file_system: AbstractFileSystem,
    prefix_path: str,
) -> "RunSpan":
    """
    Build a run fragment that continues the run described by `context`.

    The new span has type "run_fragment", reuses the captured run id and
    project, and keeps the captured trace context on `_remote_context`.

    Args:
        context: The captured run context to continue from.
        tracer: Tracer handed to the new span.
        file_system: File system handed to the new span.
        prefix_path: Prefix path handed to the new span.

    Returns:
        The newly created RunSpan fragment.
    """
    run_id = context["run_id"]
    fragment = RunSpan(
        name=f"run.{run_id}.fragment",
        project=context["project"],
        attributes={},
        tracer=tracer,
        file_system=file_system,
        prefix_path=prefix_path,
        type="run_fragment",
        run_id=run_id,
    )

    # Remember the remote trace headers on the new span
    fragment._remote_context = context["trace_context"]

    return fragment

def __enter__(self) -> te.Self:
if current_run_span.get() is not None:
raise RuntimeError("You cannot start a run span within another run")

if self._remote_context is not None:
otel_context = propagate.extract(carrier=self._remote_context)
self._remote_token = context_api.attach(otel_context)
else:
super().__enter__()

self._context_token = current_run_span.set(self)
span = super().__enter__()
self.push_update(force=True)
return span

return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
if self._remote_context is not None:
super().__enter__() # Now we can open our actual span

# When we finally close out the final span, include all the
# full data attributes, so we can skip the update spans during
# db queries later.
Expand All @@ -355,6 +400,10 @@ def __exit__(
)

super().__exit__(exc_type, exc_value, traceback)

if self._remote_token is not None:
context_api.detach(self._remote_token) # type: ignore [arg-type]

if self._context_token is not None:
current_run_span.reset(self._context_token)

Expand Down Expand Up @@ -416,7 +465,7 @@ def log_object(

# Create a composite key that represents both data and schema
hash_input = f"{data_hash}:{schema_hash}"
composite_hash = hashlib.sha1(hash_input.encode()).hexdigest()[:16] # noqa: S324
composite_hash = hashlib.sha1(hash_input.encode()).hexdigest()[:16] # noqa: S324 # nosec

# Store schema if new
if schema_hash not in self._object_schemas:
Expand Down
Loading
Loading