From 314f8a0cd656930a47511f6877742357fd031259 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Mon, 5 May 2025 19:05:38 -0700 Subject: [PATCH] Make some metric key handling stricter. Move attributes for objects to refs instead of events. Add run-level autolog specs and inheritance. Better handling for metric names. --- dreadnode/main.py | 17 +++++++++------ dreadnode/object.py | 3 +++ dreadnode/task.py | 45 ++++++++++++++++++++++++++++----------- dreadnode/tracing/span.py | 27 +++++++++++++++-------- dreadnode/types.py | 8 +++++++ 5 files changed, 72 insertions(+), 28 deletions(-) diff --git a/dreadnode/main.py b/dreadnode/main.py index d3773e70..24779700 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -47,7 +47,9 @@ current_task_span, ) from dreadnode.types import ( + INHERITED, AnyDict, + Inherited, JsonDict, JsonValue, ) @@ -412,8 +414,8 @@ def task( name: str | None = None, label: str | None = None, log_params: t.Sequence[str] | bool = False, - log_inputs: t.Sequence[str] | bool = True, - log_output: bool = True, + log_inputs: t.Sequence[str] | bool | Inherited = INHERITED, + log_output: bool | Inherited = INHERITED, tags: t.Sequence[str] | None = None, **attributes: t.Any, ) -> TaskDecorator: ... @@ -426,8 +428,8 @@ def task( name: str | None = None, label: str | None = None, log_params: t.Sequence[str] | bool = False, - log_inputs: t.Sequence[str] | bool = True, - log_output: bool = True, + log_inputs: t.Sequence[str] | bool | Inherited = INHERITED, + log_output: bool | Inherited = INHERITED, tags: t.Sequence[str] | None = None, **attributes: t.Any, ) -> ScoredTaskDecorator[R]: ... 
@@ -439,8 +441,8 @@ def task( name: str | None = None, label: str | None = None, log_params: t.Sequence[str] | bool = False, - log_inputs: t.Sequence[str] | bool = True, - log_output: bool = True, + log_inputs: t.Sequence[str] | bool | Inherited = INHERITED, + log_output: bool | Inherited = INHERITED, tags: t.Sequence[str] | None = None, **attributes: t.Any, ) -> TaskDecorator: @@ -622,6 +624,7 @@ def run( tags: t.Sequence[str] | None = None, params: AnyDict | None = None, project: str | None = None, + autolog: bool = True, **attributes: t.Any, ) -> RunSpan: """ @@ -647,6 +650,7 @@ def run( project: The project name to associate the run with. If not provided, the project passed to `configure()` will be used, or the run will be associated with a default project. + autolog: Whether to automatically log task inputs, outputs, and execution metrics if unspecified. **attributes: Additional attributes to attach to the run span. """ if not self._initialized: @@ -664,6 +668,7 @@ def run( tags=tags, file_system=self._fs, prefix_path=self._fs_prefix, + autolog=autolog, ) @handle_internal_errors() diff --git a/dreadnode/object.py b/dreadnode/object.py index 708beecd..26e7aab9 100644 --- a/dreadnode/object.py +++ b/dreadnode/object.py @@ -1,12 +1,15 @@ import typing as t from dataclasses import dataclass +from dreadnode.types import JsonDict + @dataclass class ObjectRef: name: str label: str hash: str + attributes: JsonDict @dataclass diff --git a/dreadnode/task.py b/dreadnode/task.py index 86a2ad92..a529a6c3 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -9,6 +9,7 @@ from dreadnode.metric import Scorer, ScorerCallable from dreadnode.tracing.span import TaskSpan, current_run_span +from dreadnode.types import INHERITED, Inherited P = t.ParamSpec("P") R = t.TypeVar("R") @@ -114,9 +115,9 @@ class Task(t.Generic[P, R]): log_params: t.Sequence[str] | bool = False "Whether to log all, or specific, incoming arguments to the function as parameters." 
- log_inputs: t.Sequence[str] | bool = True + log_inputs: t.Sequence[str] | bool | Inherited = INHERITED "Whether to log all, or specific, incoming arguments to the function as inputs." - log_output: bool = True + log_output: bool | Inherited = INHERITED "Whether to automatically log the result of the function as an output." def __post_init__(self) -> None: @@ -239,6 +240,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if run is None or not run.is_recording: raise RuntimeError("Tasks must be executed within a run") + log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs + log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output + bound_args = self._bind_args(*args, **kwargs) params_to_log = ( @@ -250,9 +254,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) inputs_to_log = ( bound_args - if self.log_inputs is True - else {k: v for k, v in bound_args.items() if k in self.log_inputs} - if self.log_inputs is not False + if log_inputs is True + else {k: v for k, v in bound_args.items() if k in log_inputs} + if log_inputs is not False else {} ) @@ -265,13 +269,16 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: run_id=run.run_id, tracer=self.tracer, ) as span: - span.run.log_metric(f"{self.label}.exec.count", 1, mode="count") + if run.autolog: + span.run.log_metric( + "count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True} + ) for name, value in params_to_log.items(): span.log_param(name, value) input_object_hashes: list[str] = [ - span.log_input(name, value, label=f"{self.label}.input.{name}") + span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True) for name, value in inputs_to_log.items() ] @@ -280,17 +287,29 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if inspect.isawaitable(output): output = await output except Exception: - 
span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg") + if run.autolog: + span.run.log_metric( + "success_rate", + 0, + prefix=f"{self.label}.exec", + mode="avg", + attributes={"auto": True}, + ) raise - span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg") + if run.autolog: + span.run.log_metric( + "success_rate", + 1, + prefix=f"{self.label}.exec", + mode="avg", + attributes={"auto": True}, + ) span.output = output - if self.log_output: + if log_output: output_object_hash = span.log_output( - "output", - output, - label=f"{self.label}.output", + "output", output, label=f"{self.label}.output", auto=True ) # Link the output to the inputs diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 4e611ab7..a14b2125 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -250,11 +250,15 @@ def __init__( tracer: Tracer, file_system: AbstractFileSystem, prefix_path: str, + *, params: AnyDict | None = None, metrics: MetricDict | None = None, run_id: str | None = None, tags: t.Sequence[str] | None = None, + autolog: bool = True, ) -> None: + self.autolog = autolog + self._params = params or {} self._metrics = metrics or {} self._objects: dict[str, Object] = {} @@ -486,9 +490,8 @@ def log_input( value, label=label, event_name=EVENT_NAME_OBJECT_INPUT, - **attributes, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_)) + self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) def log_artifact( self, @@ -528,6 +531,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, + prefix: str | None = None, attributes: JsonDict | None = None, ) -> None: ... @@ -539,6 +543,7 @@ def log_metric( *, origin: t.Any | None = None, mode: MetricAggMode | None = None, + prefix: str | None = None, ) -> None: ... 
def log_metric( @@ -550,6 +555,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, + prefix: str | None = None, attributes: JsonDict | None = None, ) -> None: metric = ( @@ -560,6 +566,10 @@ def log_metric( ) ) + key = re.sub(r"[^\w/]+", "_", key.lower()) + if prefix is not None: + key = f"{prefix}.{key}" + if origin is not None: origin_hash = self.log_object( origin, @@ -590,9 +600,8 @@ def log_output( value, label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, - **attributes, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_)) + self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) class TaskSpan(Span, t.Generic[R]): @@ -694,9 +703,8 @@ def log_output( value, label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, - **attributes, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_)) + self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) return hash_ @property @@ -726,9 +734,8 @@ def log_input( value, label=label, event_name=EVENT_NAME_OBJECT_INPUT, - **attributes, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_)) + self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) return hash_ @property @@ -777,6 +784,8 @@ def log_metric( ) ) + key = re.sub(r"[^\w/]+", "_", key.lower()) + if origin is not None: origin_hash = self.run.log_object( origin, @@ -795,7 +804,7 @@ def log_metric( # # Don't include `source` and `mode` as we handled it here. 
if (run := current_run_span.get()) is not None: - run.log_metric(f"{self._label}.{key}", metric) + run.log_metric(key, metric, prefix=self._label) def get_average_metric_value(self, key: str | None = None) -> float: metrics = ( diff --git a/dreadnode/types.py b/dreadnode/types.py index 030b3d35..005b8dc0 100644 --- a/dreadnode/types.py +++ b/dreadnode/types.py @@ -23,3 +23,11 @@ def __bool__(self) -> t.Literal[False]: UNSET: Unset = Unset() + + +class Inherited: + def __repr__(self) -> str: + return "Inherited" + + +INHERITED: Inherited = Inherited()