diff --git a/dreadnode/api/util.py b/dreadnode/api/util.py index bf8b6ebc..004e3a91 100644 --- a/dreadnode/api/util.py +++ b/dreadnode/api/util.py @@ -59,11 +59,13 @@ def process_run(run: RawRun) -> Run: for references, converted in ((run.inputs, inputs), (run.outputs, outputs)): for ref in references: if (_object := run.objects.get(ref.hash)) is None: - logger.error("Object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": # In-progress runs may not have all the objects ready + logger.error("Object %s not found in run %s", ref.hash, run.id) continue if (_schema := run.object_schemas.get(_object.schema_hash)) is None: - logger.error("Schema for object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Schema for object %s not found in run %s", ref.hash, run.id) continue if isinstance(_object, RawObjectVal): @@ -123,11 +125,13 @@ def process_task(task: RawTask, run: RawRun) -> Task: continue if (_object := run.objects.get(ref.hash)) is None: - logger.error("Object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Object %s not found in run %s", ref.hash, run.id) continue if (_schema := run.object_schemas.get(_object.schema_hash)) is None: - logger.error("Schema for object %s not found in run %s", ref.hash, run.id) + if run.status != "pending": + logger.error("Schema for object %s not found in run %s", ref.hash, run.id) continue if isinstance(_object, RawObjectVal): diff --git a/dreadnode/main.py b/dreadnode/main.py index ccefdb85..fec703ec 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -230,7 +230,7 @@ def initialize(self) -> None: self._api.list_projects() except Exception as e: raise RuntimeError( - "Failed to connect to the Dreadnode server.", + f"Failed to connect to the Dreadnode server: {e}", ) from e headers = {"User-Agent": f"dreadnode/{VERSION}", "X-Api-Key": self.token} @@ -707,11 +707,11 @@ def tag(self, *tag: str, to: ToObject = "task-or-run") -> None: @handle_internal_errors() def push_update(self) -> None: """ - Push any pending metric or parameter data to the server. + Push any pending run data to the server before run completion. This is useful for ensuring that the UI is up to date with the - latest data. Otherwise, all data for the run will be pushed - automatically when the run is closed. + latest data. Data is automatically pushed periodically, but + you can call this method to force a push. Example: ``` @@ -719,11 +719,13 @@ def push_update(self) -> None: dreadnode.log_params(...) dreadnode.log_metric(...) dreadnode.push_update() + + # do more work """ if (run := current_run_span.get()) is None: raise RuntimeError("Run updates must be pushed within a run") - run.push_update() + run.push_update(force=True) @handle_internal_errors() def log_param( diff --git a/dreadnode/task.py b/dreadnode/task.py index a529a6c3..c9b2b104 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -320,6 +320,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) + # Trigger a run update whenever a task completes + run.push_update() + return span async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 49b2dcef..13d68642 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -1,5 +1,6 @@ import hashlib import logging +import time import types import typing as t from contextvars import ContextVar, Token @@ -233,19 +234,30 @@ def __init__( *, metrics: MetricDict | None = None, params: JsonDict | None = None, - inputs: JsonDict | None = None, + inputs: list[ObjectRef] | None = None, + outputs: list[ObjectRef] | None = None, + objects: dict[str, Object] | None = None, + object_schemas: dict[str, JsonDict] | None = None, ) -> None: attributes: AnyDict = { SPAN_ATTRIBUTE_RUN_ID: run_id, SPAN_ATTRIBUTE_PROJECT: project, + **({SPAN_ATTRIBUTE_METRICS: metrics} if metrics else {}), + **({SPAN_ATTRIBUTE_PARAMS: params} if params else {}), + **({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}), + **({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}), + **({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}), + **({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}), } - if metrics: - attributes[SPAN_ATTRIBUTE_METRICS] = metrics - if params: - attributes[SPAN_ATTRIBUTE_PARAMS] = params - if inputs: - attributes[SPAN_ATTRIBUTE_INPUTS] = inputs + # Mark objects and schemas as large attributes if present + if objects or object_schemas: + large_attrs = [] + if objects: + large_attrs.append(SPAN_ATTRIBUTE_OBJECTS) + if object_schemas: + large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS) + attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update") @@ -265,8 +277,10 @@ def __init__( run_id: str | None = None, tags: t.Sequence[str] | None = None, autolog: bool = True, + update_frequency: int = 5, ) -> None: self.autolog = autolog + self.project = project self._params = params or {} self._metrics = metrics or {} @@ -281,10 +295,16 @@ def __init__( storage=self._artifact_storage, prefix_path=prefix_path, ) - self.project = project - self._last_pushed_params = deepcopy(self._params) - self._last_pushed_metrics = deepcopy(self._metrics) + # Update mechanics + self._last_update_time = time.time() + self._update_frequency = update_frequency + self._pending_params = deepcopy(self._params) + self._pending_inputs = deepcopy(self._inputs) + self._pending_outputs = deepcopy(self._outputs) + self._pending_metrics = deepcopy(self._metrics) + self._pending_objects = deepcopy(self._objects) + self._pending_object_schemas = deepcopy(self._object_schemas) self._context_token: Token[RunSpan | None] | None = None # contextvars context self._file_system = file_system @@ -293,8 +313,6 @@ def __init__( attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()), SPAN_ATTRIBUTE_PROJECT: project, - SPAN_ATTRIBUTE_PARAMS: self._params, - SPAN_ATTRIBUTE_METRICS: self._metrics, **attributes, } super().__init__(name, attributes, tracer, type="run", tags=tags) @@ -304,7 +322,9 @@ def __enter__(self) -> te.Self: raise RuntimeError("You cannot start a run span within another run") self._context_token = current_run_span.set(self) - return super().__enter__() + span = super().__enter__() + self.push_update(force=True) + return span def __exit__( self, @@ -312,7 +332,10 @@ def __exit__( exc_value: BaseException | None, traceback: types.TracebackType | None, ) -> None: - self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params) + # When we finally close out the final span, include all the + # full data attributes, so we can skip the update spans during + # db queries later. + self.set_attribute(SPAN_ATTRIBUTE_PARAMS, self._params, schema=False) self.set_attribute(SPAN_ATTRIBUTE_INPUTS, self._inputs, schema=False) self.set_attribute(SPAN_ATTRIBUTE_OUTPUTS, self._outputs, schema=False) self.set_attribute(SPAN_ATTRIBUTE_METRICS, self._metrics, schema=False) @@ -335,32 +358,46 @@ def __exit__( if self._context_token is not None: current_run_span.reset(self._context_token) - def push_update(self) -> None: + def push_update(self, *, force: bool = False) -> None: if self._span is None: return - metrics: MetricDict | None = None - if self._last_pushed_metrics != self._metrics: - metrics = self._metrics - self._last_pushed_metrics = deepcopy(self._metrics) - - params: JsonDict | None = None - if self._last_pushed_params != self._params: - params = self._params - self._last_pushed_params = deepcopy(self._params) + current_time = time.time() + force_update = force or (current_time - self._last_update_time >= self._update_frequency) + should_update = force_update and ( + self._pending_params + or self._pending_inputs + or self._pending_outputs + or self._pending_metrics + or self._pending_objects + or self._pending_object_schemas + ) - if metrics is None and params is None: + if not should_update: return with RunUpdateSpan( run_id=self.run_id, project=self.project, tracer=self._tracer, - params=params, - metrics=metrics, + metrics=self._pending_metrics if self._pending_metrics else None, + params=self._pending_params if self._pending_params else None, + inputs=self._pending_inputs if self._pending_inputs else None, + outputs=self._pending_outputs if self._pending_outputs else None, + objects=self._pending_objects if self._pending_objects else None, + object_schemas=self._pending_object_schemas if self._pending_object_schemas else None, ): pass + self._pending_metrics.clear() + self._pending_params.clear() + self._pending_inputs.clear() + self._pending_outputs.clear() + self._pending_objects.clear() + self._pending_object_schemas.clear() + + self._last_update_time = current_time + @property def run_id(self) -> str: return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "")) @@ -384,6 +421,7 @@ def log_object( # Store schema if new if schema_hash not in self._object_schemas: self._object_schemas[schema_hash] = serialized.schema + self._pending_object_schemas[schema_hash] = serialized.schema # Check if we already have this exact composite hash if composite_hash not in self._objects: @@ -392,8 +430,7 @@ def log_object( # Store with composite hash so we can look it up by the combination self._objects[composite_hash] = obj - - object_ = self._objects[composite_hash] + self._pending_objects[composite_hash] = obj # Build event attributes, use composite hash in events event_attributes = { @@ -407,7 +444,9 @@ def log_object( event_attributes[EVENT_ATTRIBUTE_OBJECT_LABEL] = label self.log_event(name=event_name, attributes=event_attributes) - return object_.hash + self.push_update() + + return composite_hash def _store_file_by_hash(self, data: bytes, full_path: str) -> str: """ @@ -488,9 +527,10 @@ def log_param(self, key: str, value: t.Any) -> None: def log_params(self, **params: t.Any) -> None: for key, value in params.items(): self._params[key] = value + self._pending_params[key] = value - # Always push updates for run params - self.push_update() + # Params should get pushed immediately + self.push_update(force=True) @property def inputs(self) -> AnyDict: @@ -510,7 +550,9 @@ def log_input( label=label, event_name=EVENT_NAME_OBJECT_INPUT, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes) + self._inputs.append(object_ref) + self._pending_inputs.append(object_ref) def log_artifact( self, @@ -529,11 +571,8 @@ def log_artifact( Raises: FileNotFoundError: If the path doesn't exist """ - artifact_tree = self._artifact_tree_builder.process_artifact(local_uri) - self._artifact_merger.add_tree(artifact_tree) - self._artifacts = self._artifact_merger.get_merged_trees() @property @@ -601,6 +640,7 @@ def log_metric( if mode is not None: metric = metric.apply_mode(mode, metrics) metrics.append(metric) + self._pending_metrics.setdefault(key, []).append(metric) return metric @@ -622,7 +662,9 @@ def log_output( label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + object_ref = ObjectRef(name, label=label, hash=hash_, attributes=attributes) + self._outputs.append(object_ref) + self._pending_outputs.append(object_ref) class TaskSpan(Span, t.Generic[R]):