From 56aedb452e046c92b621b2539a4ea05f31cba35c Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 14 May 2025 18:13:25 -0600 Subject: [PATCH 1/4] Refactored run update mechanics for more data handling and higher frequency --- dreadnode/api/util.py | 12 ++-- dreadnode/main.py | 12 ++-- dreadnode/task.py | 3 + dreadnode/tracing/span.py | 117 +++++++++++++++++++++++++------------- 4 files changed, 97 insertions(+), 47 deletions(-) diff --git a/dreadnode/api/util.py b/dreadnode/api/util.py index bf8b6ebc..004e3a91 100644 --- a/dreadnode/api/util.py +++ b/dreadnode/api/util.py @@ -59,11 +59,13 @@ def process_run(run: RawRun) -> Run: for references, converted in ((run.inputs, inputs), (run.outputs, outputs)): for ref in references: if (_object := run.objects.get(ref.hash)) is None: - logger.error("Object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": # In-progress runs may not have all the objects ready + logger.error("Object %s not found in run %s", ref.hash, run.id) continue if (_schema := run.object_schemas.get(_object.schema_hash)) is None: - logger.error("Schema for object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Schema for object %s not found in run %s", ref.hash, run.id) continue if isinstance(_object, RawObjectVal): @@ -123,11 +125,13 @@ def process_task(task: RawTask, run: RawRun) -> Task: continue if (_object := run.objects.get(ref.hash)) is None: - logger.error("Object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Object %s not found in run %s", ref.hash, run.id) continue if (_schema := run.object_schemas.get(_object.schema_hash)) is None: - logger.error("Schema for object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Schema for object %s not found in run %s", ref.hash, run.id) continue if isinstance(_object, RawObjectVal): diff --git a/dreadnode/main.py b/dreadnode/main.py index f47461ba..c9ded67e 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -232,7 +232,7 @@ def initialize(self) -> None: self._api.list_projects() except Exception as e: raise RuntimeError( - "Failed to authenticate with the provided server and token", + f"Failed to connect to Dreadnode: {e}", ) from e headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token} @@ -684,11 +684,11 @@ def run( @handle_internal_errors() def push_update(self) -> None: """ - Push any pending metric or parameter data to the server. + Push any pending run data to the server before run completion. This is useful for ensuring that the UI is up to date with the - latest data. Otherwise, all data for the run will be pushed - automatically when the run is closed. + latest data. Data is automatically pushed periodically, but + you can call this method to force a push. Example: ``` @@ -696,11 +696,13 @@ def push_update(self) -> None: dreadnode.log_params(...) dreadnode.log_metric(...) dreadnode.push_update() + + # do more work """ if (run := current_run_span.get()) is None: raise RuntimeError("Run updates must be pushed within a run") - run.push_update() + run.push_update(force=True) @handle_internal_errors() def log_param( diff --git a/dreadnode/task.py b/dreadnode/task.py index a529a6c3..c9b2b104 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -320,6 +320,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) + # Trigger a run update whenever a task completes + run.push_update() + return span async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index a14b2125..45d7b906 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -1,5 +1,6 @@ import logging import re +import time import types import typing as t from contextvars import ContextVar, Token @@ -224,19 +225,30 @@ def __init__( *, metrics: MetricDict | None = None, params: JsonDict | None = None, - inputs: JsonDict | None = None, + inputs: list[ObjectRef] | None = None, + outputs: list[ObjectRef] | None = None, + objects: dict[str, Object] | None = None, + object_schemas: dict[str, JsonDict] | None = None, ) -> None: attributes: AnyDict = { SPAN_ATTRIBUTE_RUN_ID: run_id, SPAN_ATTRIBUTE_PROJECT: project, + **({SPAN_ATTRIBUTE_METRICS: metrics} if metrics else {}), + **({SPAN_ATTRIBUTE_PARAMS: params} if params else {}), + **({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}), + **({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}), + **({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}), + **({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}), } - if metrics: - attributes[SPAN_ATTRIBUTE_METRICS] = metrics - if params: - attributes[SPAN_ATTRIBUTE_PARAMS] = params - if inputs: - attributes[SPAN_ATTRIBUTE_INPUTS] = inputs + # Mark objects and schemas as large attributes if present + if objects or object_schemas: + large_attrs = [] + if objects: + large_attrs.append(SPAN_ATTRIBUTE_OBJECTS) + if object_schemas: + large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS) + attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update") @@ -256,8 +268,10 @@ def __init__( run_id: str | None = None, tags: t.Sequence[str] | None = None, autolog: bool = True, + update_frequency: int = 5, ) -> None: self.autolog = autolog + self.project = project self._params = params or {} self._metrics = metrics or {} @@ -272,10 +286,17 @@ def __init__( storage=self._artifact_storage, prefix_path=prefix_path, ) - self.project = project - self._last_pushed_params = deepcopy(self._params) - self._last_pushed_metrics = deepcopy(self._metrics) + # Update mechanics + self._last_update_time = time.time() + self._update_frequency = update_frequency + self._pending_params = deepcopy(self._params) + self._pending_inputs = deepcopy(self._inputs) + self._pending_outputs = deepcopy(self._outputs) + self._pending_artifacts = deepcopy(self._artifacts) + self._pending_metrics = deepcopy(self._metrics) + self._pending_objects = deepcopy(self._objects) + self._pending_object_schemas = deepcopy(self._object_schemas) self._context_token: Token[RunSpan | None] | None = None # contextvars context self._file_system = file_system @@ -284,8 +305,6 @@ def __init__( attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()), SPAN_ATTRIBUTE_PROJECT: project, - SPAN_ATTRIBUTE_PARAMS: self._params, - SPAN_ATTRIBUTE_METRICS: self._metrics, **attributes, } super().__init__(name, attributes, tracer, type="run", tags=tags) @@ -303,14 +322,14 @@ def __exit__( exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params) - self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_OBJECTS, self._objects, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._pending_params) + self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._pending_inputs, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._pending_outputs, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._pending_metrics, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_OBJECTS, self._pending_objects, schema=False) self.set_attribute( SPAN_ATTRIBUTE_OBJECT_SCHEMAS, - self._object_schemas, + self._pending_object_schemas, schema=False, ) self.set_attribute(SPAN_ATTRIBUTE_ARTIFACTS, self._artifacts, schema=False) @@ -326,32 +345,47 @@ def __exit__( if self._context_token is not None: current_run_span.reset(self._context_token) - def push_update(self) -> None: + def push_update(self, *, force: bool = False) -> None: if self._span is None: return - metrics: MetricDict | None = None - if self._last_pushed_metrics != self._metrics: - metrics = self._metrics - self._last_pushed_metrics = deepcopy(self._metrics) - - params: JsonDict | None = None - if self._last_pushed_params != self._params: - params = self._params - self._last_pushed_params = deepcopy(self._params) + current_time = time.time() + force_update = force or (current_time - self._last_update_time >= self._update_frequency) + should_update = force_update and ( + self._pending_params + or self._pending_inputs + or self._pending_outputs + or self._pending_artifacts + or self._pending_metrics + or self._pending_objects + or self._pending_object_schemas + ) - if metrics is None and params is None: + if not should_update: return with RunUpdateSpan( run_id=self.run_id, project=self.project, tracer=self._tracer, - params=params, - metrics=metrics, + metrics=self._pending_metrics if self._pending_metrics else None, + params=self._pending_params if self._pending_params else None, + inputs=self._pending_inputs if self._pending_inputs else None, + outputs=self._pending_outputs if self._pending_outputs else None, + objects=self._pending_objects if self._pending_objects else None, + object_schemas=self._pending_object_schemas if self._pending_object_schemas else None, ): pass + self._pending_metrics.clear() + self._pending_params.clear() + self._pending_inputs.clear() + self._pending_outputs.clear() + self._pending_objects.clear() + self._pending_object_schemas.clear() + + self._last_update_time = current_time + @property def run_id(self) -> str: return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "")) @@ -371,12 +405,14 @@ def log_object( # Store object if we haven't already if data_hash not in self._objects: self._objects[data_hash] = self._create_object(serialized) + self._pending_objects[data_hash] = self._objects[data_hash] object_ = self._objects[data_hash] # Store schema if new if schema_hash not in self._object_schemas: self._object_schemas[schema_hash] = serialized.schema + self._pending_object_schemas[schema_hash] = serialized.schema # Build event attributes event_attributes = { @@ -390,6 +426,8 @@ def log_object( event_attributes[EVENT_ATTRIBUTE_OBJECT_LABEL] = label self.log_event(name=event_name, attributes=event_attributes) + self.push_update() + return object_.hash def _store_file_by_hash(self, data: bytes, full_path: str) -> str: @@ -469,9 +507,10 @@ def log_param(self, key: str, value: t.Any) -> None: def log_params(self, **params: t.Any) -> None: for key, value in params.items(): self._params[key] = value + self._pending_params[key] = value - # Always push updates for run params - self.push_update() + # Params should get pushed immediately + self.push_update(force=True) @property def inputs(self) -> AnyDict: @@ -491,7 +530,9 @@ def log_input( label=label, event_name=EVENT_NAME_OBJECT_INPUT, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes) + self._inputs.append(object_ref) + self._pending_inputs.append(object_ref) def log_artifact( self, @@ -510,11 +551,8 @@ def log_artifact( Raises: FileNotFoundError: If the path doesn't exist """ - artifact_tree = self._artifact_tree_builder.process_artifact(local_uri) - self._artifact_merger.add_tree(artifact_tree) - self._artifacts = self._artifact_merger.get_merged_trees() @property @@ -582,6 +620,7 @@ def log_metric( if mode is not None: metric = metric.apply_mode(mode, metrics) metrics.append(metric) + self._pending_metrics.setdefault(key, []).append(metric) @property def outputs(self) -> AnyDict: @@ -601,7 +640,9 @@ def log_output( label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes) + self._outputs.append(object_ref) + self._pending_outputs.append(object_ref) class TaskSpan(Span, t.Generic[R]): From 1fc8e0dcdbdaff7f65da861dfd0b4c0279fde0d6 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 14 May 2025 18:20:32 -0600 Subject: [PATCH 2/4] Push an initial update right when the run starts --- dreadnode/tracing/span.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 45d7b906..f5bbf90a 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -314,7 +314,9 @@ def __enter__(self) -> te.Self: raise RuntimeError("You cannot start a run span within another run") self._context_token = current_run_span.set(self) - return super().__enter__() + span = super().__enter__() + self.push_update(force=True) + return span def __exit__( self, From b7b89b1f6105c39b41e27c86174d5bf0056fff6f Mon Sep 17 00:00:00 2001 From: monoxgas Date: Tue, 20 May 2025 10:35:27 -0600 Subject: [PATCH 3/4] Remove unused _pending_artifact --- dreadnode/tracing/span.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index f5bbf90a..2bf4baeb 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -293,7 +293,6 @@ def __init__( self._pending_params = deepcopy(self._params) self._pending_inputs = deepcopy(self._inputs) self._pending_outputs = deepcopy(self._outputs) - self._pending_artifacts = deepcopy(self._artifacts) self._pending_metrics = deepcopy(self._metrics) self._pending_objects = deepcopy(self._objects) self._pending_object_schemas = deepcopy(self._object_schemas) @@ -357,7 +356,6 @@ def push_update(self, *, force: bool = False) -> None: self._pending_params or self._pending_inputs or self._pending_outputs - or self._pending_artifacts or self._pending_metrics or self._pending_objects or self._pending_object_schemas From 5120858ef689f32b752cc5339db67d50086fbd34 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Fri, 23 May 2025 14:07:45 -0600 Subject: [PATCH 4/4] Send full attribute data in final run span --- dreadnode/tracing/span.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 2bf4baeb..23964f4e 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -323,14 +323,17 @@ def __exit__( exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._pending_params) - self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._pending_inputs, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._pending_outputs, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._pending_metrics, schema=False) - self.set_attribute(SPAN_ATTRIBUTE_OBJECTS, self._pending_objects, schema=False) + # When we finally close out the final span, include all the + # full data attributes, so we can skip the update spans during + # db queries later. + self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False) + self.set_attribute(SPAN_ATTRIBUTE_OBJECTS, self._objects, schema=False) self.set_attribute( SPAN_ATTRIBUTE_OBJECT_SCHEMAS, - self._pending_object_schemas, + self._object_schemas, schema=False, ) self.set_attribute(SPAN_ATTRIBUTE_ARTIFACTS, self._artifacts, schema=False)