Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions dreadnode/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def process_run(run: RawRun) -> Run:
for references, converted in ((run.inputs, inputs), (run.outputs, outputs)):
for ref in references:
if (_object := run.objects.get(ref.hash)) is None:
logger.error("Object %s not found in run %s", ref.hash, run.id)
if run.status != "pending": # In-progress runs may not have all the objects ready
logger.error("Object %s not found in run %s", ref.hash, run.id)
continue

if (_schema := run.object_schemas.get(_object.schema_hash)) is None:
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
if run.status != "pending":
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
continue

if isinstance(_object, RawObjectVal):
Expand Down Expand Up @@ -123,11 +125,13 @@ def process_task(task: RawTask, run: RawRun) -> Task:
continue

if (_object := run.objects.get(ref.hash)) is None:
logger.error("Object %s not found in run %s", ref.hash, run.id)
if run.status != "pending":
logger.error("Object %s not found in run %s", ref.hash, run.id)
continue

if (_schema := run.object_schemas.get(_object.schema_hash)) is None:
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
if run.status != "pending":
logger.error("Schema for object %s not found in run %s", ref.hash, run.id)
continue

if isinstance(_object, RawObjectVal):
Expand Down
12 changes: 7 additions & 5 deletions dreadnode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def initialize(self) -> None:
self._api.list_projects()
except Exception as e:
raise RuntimeError(
"Failed to connect to the Dreadnode server.",
f"Failed to connect to the Dreadnode server: {e}",
) from e

headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token}
Expand Down Expand Up @@ -707,23 +707,25 @@ def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:
@handle_internal_errors()
def push_update(self) -> None:
"""
Push any pending metric or parameter data to the server.
Push any pending run data to the server before run completion.

This is useful for ensuring that the UI is up to date with the
latest data. Otherwise, all data for the run will be pushed
automatically when the run is closed.
latest data. Data is automatically pushed periodically, but
you can call this method to force a push.

Example:
```
with dreadnode.run("my_run"):
dreadnode.log_params(...)
dreadnode.log_metric(...)
dreadnode.push_update()

# do more work
"""
if (run := current_run_span.get()) is None:
raise RuntimeError("Run updates must be pushed within a run")

run.push_update()
run.push_update(force=True)

@handle_internal_errors()
def log_param(
Expand Down
3 changes: 3 additions & 0 deletions dreadnode/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
metric = await scorer(output)
span.log_metric(scorer.name, metric, origin=output)

# Trigger a run update whenever a task completes
run.push_update()

return span

async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
Expand Down
116 changes: 79 additions & 37 deletions dreadnode/tracing/span.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import hashlib
import logging
import time
import types
import typing as t
from contextvars import ContextVar, Token
Expand Down Expand Up @@ -233,19 +234,30 @@ def __init__(
*,
metrics: MetricDict | None = None,
params: JsonDict | None = None,
inputs: JsonDict | None = None,
inputs: list[ObjectRef] | None = None,
outputs: list[ObjectRef] | None = None,
objects: dict[str, Object] | None = None,
object_schemas: dict[str, JsonDict] | None = None,
) -> None:
attributes: AnyDict = {
SPAN_ATTRIBUTE_RUN_ID: run_id,
SPAN_ATTRIBUTE_PROJECT: project,
**({SPAN_ATTRIBUTE_METRICS: metrics} if metrics else {}),
**({SPAN_ATTRIBUTE_PARAMS: params} if params else {}),
**({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}),
**({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}),
**({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}),
**({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}),
}

if metrics:
attributes[SPAN_ATTRIBUTE_METRICS] = metrics
if params:
attributes[SPAN_ATTRIBUTE_PARAMS] = params
if inputs:
attributes[SPAN_ATTRIBUTE_INPUTS] = inputs
# Mark objects and schemas as large attributes if present
if objects or object_schemas:
large_attrs = []
if objects:
large_attrs.append(SPAN_ATTRIBUTE_OBJECTS)
if object_schemas:
large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS)
attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs

super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update")

Expand All @@ -265,8 +277,10 @@ def __init__(
run_id: str | None = None,
tags: t.Sequence[str] | None = None,
autolog: bool = True,
update_frequency: int = 5,
) -> None:
self.autolog = autolog
self.project = project

self._params = params or {}
self._metrics = metrics or {}
Expand All @@ -281,10 +295,16 @@ def __init__(
storage=self._artifact_storage,
prefix_path=prefix_path,
)
self.project = project

self._last_pushed_params = deepcopy(self._params)
self._last_pushed_metrics = deepcopy(self._metrics)
# Update mechanics
self._last_update_time = time.time()
self._update_frequency = update_frequency
self._pending_params = deepcopy(self._params)
self._pending_inputs = deepcopy(self._inputs)
self._pending_outputs = deepcopy(self._outputs)
self._pending_metrics = deepcopy(self._metrics)
self._pending_objects = deepcopy(self._objects)
self._pending_object_schemas = deepcopy(self._object_schemas)

self._context_token: Token[RunSpan | None] | None = None # contextvars context
self._file_system = file_system
Expand All @@ -293,8 +313,6 @@ def __init__(
attributes = {
SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()),
SPAN_ATTRIBUTE_PROJECT: project,
SPAN_ATTRIBUTE_PARAMS: self._params,
SPAN_ATTRIBUTE_METRICS: self._metrics,
**attributes,
}
super().__init__(name, attributes, tracer, type="run", tags=tags)
Expand All @@ -304,15 +322,20 @@ def __enter__(self) -> te.Self:
raise RuntimeError("You cannot start a run span within another run")

self._context_token = current_run_span.set(self)
return super().__enter__()
span = super().__enter__()
self.push_update(force=True)
return span

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params)
# When we finally close out the final span, include all the
# full data attributes, so we can skip the update spans during
# db queries later.
self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params, schema=False)
self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False)
self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False)
self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False)
Expand All @@ -335,32 +358,46 @@ def __exit__(
if self._context_token is not None:
current_run_span.reset(self._context_token)

def push_update(self) -> None:
def push_update(self, *, force: bool = False) -> None:
if self._span is None:
return

metrics: MetricDict | None = None
if self._last_pushed_metrics != self._metrics:
metrics = self._metrics
self._last_pushed_metrics = deepcopy(self._metrics)

params: JsonDict | None = None
if self._last_pushed_params != self._params:
params = self._params
self._last_pushed_params = deepcopy(self._params)
current_time = time.time()
force_update = force or (current_time - self._last_update_time >= self._update_frequency)
should_update = force_update and (
self._pending_params
or self._pending_inputs
or self._pending_outputs
or self._pending_metrics
or self._pending_objects
or self._pending_object_schemas
)

if metrics is None and params is None:
if not should_update:
return

with RunUpdateSpan(
run_id=self.run_id,
project=self.project,
tracer=self._tracer,
params=params,
metrics=metrics,
metrics=self._pending_metrics if self._pending_metrics else None,
params=self._pending_params if self._pending_params else None,
inputs=self._pending_inputs if self._pending_inputs else None,
outputs=self._pending_outputs if self._pending_outputs else None,
objects=self._pending_objects if self._pending_objects else None,
object_schemas=self._pending_object_schemas if self._pending_object_schemas else None,
):
pass

self._pending_metrics.clear()
self._pending_params.clear()
self._pending_inputs.clear()
self._pending_outputs.clear()
self._pending_objects.clear()
self._pending_object_schemas.clear()

self._last_update_time = current_time

@property
def run_id(self) -> str:
return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, ""))
Expand All @@ -384,6 +421,7 @@ def log_object(
# Store schema if new
if schema_hash not in self._object_schemas:
self._object_schemas[schema_hash] = serialized.schema
self._pending_object_schemas[schema_hash] = serialized.schema

# Check if we already have this exact composite hash
if composite_hash not in self._objects:
Expand All @@ -392,8 +430,7 @@ def log_object(

# Store with composite hash so we can look it up by the combination
self._objects[composite_hash] = obj

object_ = self._objects[composite_hash]
self._pending_objects[composite_hash] = obj

# Build event attributes, use composite hash in events
event_attributes = {
Expand All @@ -407,7 +444,9 @@ def log_object(
event_attributes[EVENT_ATTRIBUTE_OBJECT_LABEL] = label

self.log_event(name=event_name, attributes=event_attributes)
return object_.hash
self.push_update()

return composite_hash

def _store_file_by_hash(self, data: bytes, full_path: str) -> str:
"""
Expand Down Expand Up @@ -488,9 +527,10 @@ def log_param(self, key: str, value: t.Any) -> None:
def log_params(self, **params: t.Any) -> None:
for key, value in params.items():
self._params[key] = value
self._pending_params[key] = value

# Always push updates for run params
self.push_update()
# Params should get pushed immediately
self.push_update(force=True)

@property
def inputs(self) -> AnyDict:
Expand All @@ -510,7 +550,9 @@ def log_input(
label=label,
event_name=EVENT_NAME_OBJECT_INPUT,
)
self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes)
self._inputs.append(object_ref)
self._pending_inputs.append(object_ref)

def log_artifact(
self,
Expand All @@ -529,11 +571,8 @@ def log_artifact(
Raises:
FileNotFoundError: If the path doesn't exist
"""

artifact_tree = self._artifact_tree_builder.process_artifact(local_uri)

self._artifact_merger.add_tree(artifact_tree)

self._artifacts = self._artifact_merger.get_merged_trees()

@property
Expand Down Expand Up @@ -601,6 +640,7 @@ def log_metric(
if mode is not None:
metric = metric.apply_mode(mode, metrics)
metrics.append(metric)
self._pending_metrics.setdefault(key, []).append(metric)

return metric

Expand All @@ -622,7 +662,9 @@ def log_output(
label=label,
event_name=EVENT_NAME_OBJECT_OUTPUT,
)
self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes))
object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes)
self._outputs.append(object_ref)
self._pending_outputs.append(object_ref)


class TaskSpan(Span, t.Generic[R]):
Expand Down