diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 463cd670..3ebd624b 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -492,7 +492,9 @@ def initialize(self) -> None: ```python link_objects( - origin: Any, link: Any, **attributes: JsonValue + origin: Any, + link: Any, + attributes: AnyDict | None = None, ) -> None ``` @@ -520,16 +522,21 @@ with dreadnode.run("my_run"): * **`link`** (`Any`) –The linked object to link to. -* **`**attributes`** - (`JsonValue`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to attach to the link. ```python @handle_internal_errors() -def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> None: +def link_objects( + self, + origin: t.Any, + link: t.Any, + attributes: AnyDict | None = None, +) -> None: """ Associate two runtime objects with each other. @@ -549,14 +556,14 @@ def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> N Args: origin: The origin object to link from. link: The linked object to link to. - **attributes: Additional attributes to attach to the link. + attributes: Additional attributes to attach to the link. """ if (run := current_run_span.get()) is None: raise RuntimeError("link() must be called within a run") origin_hash = run.log_object(origin) link_hash = run.log_object(link) - run.link_objects(origin_hash, link_hash, **attributes) + run.link_objects(origin_hash, link_hash, attributes=attributes) ``` @@ -662,7 +669,7 @@ log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: Any, + attributes: AnyDict | None = None, ) -> None ``` @@ -696,7 +703,7 @@ def log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> None: """ Log a single input to the current task or run. @@ -724,7 +731,7 @@ def log_input( if target is None: raise RuntimeError("log_inputs() must be called within a run") - target.log_input(name, value, label=label, **attributes) + target.log_input(name, value, label=label, attributes=attributes) ``` @@ -734,7 +741,7 @@ def log_input( ```python log_inputs( - to: ToObject = "task-or-run", **inputs: JsonValue + to: ToObject = "task-or-run", **inputs: Any ) -> None ``` @@ -748,7 +755,7 @@ See `log_input()` for more details. def log_inputs( self, to: ToObject = "task-or-run", - **inputs: JsonValue, + **inputs: t.Any, ) -> None: """ Log multiple inputs to the current task or run. @@ -773,7 +780,7 @@ log_metric( origin: Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric ``` @@ -798,7 +805,7 @@ log_metric( origin: Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric ``` @@ -863,7 +870,7 @@ with dreadnode.run("my_run"): - sum: the cumulative sum of all reported values for this metric - count: increment every time this metric is logged - disregard value * **`attributes`** - (`JsonDict | None`, default: + (`AnyDict | None`, default: `None` ) –A dictionary of additional attributes to attach to the metric. @@ -892,7 +899,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -971,7 +978,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -983,7 +990,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -995,7 +1002,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -1050,7 +1057,7 @@ dreadnode.log_metrics( ) –Default aggregation mode for metrics if not supplied. * **`attributes`** - (`JsonDict | None`, default: + (`AnyDict | None`, default: `None` ) –Default attributes for metrics if not supplied. @@ -1077,7 +1084,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1173,7 +1180,7 @@ log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None ``` @@ -1197,6 +1204,31 @@ with dreadnode.run("my_run"): dreadnode.log_output("other", 123) ``` +**Parameters:** + +* **`name`** + (`str`) + –The name of the output. +* **`value`** + (`Any`) + –The value of the output. +* **`label`** + (`str | None`, default: + `None` + ) + –An optional label for the output, useful for filtering in the UI. +* **`to`** + (`ToObject`, default: + `'task-or-run'` + ) + –The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. +* **`attributes`** + (`AnyDict | None`, default: + `None` + ) + –Additional attributes to attach to the output. ```python @@ -1208,7 +1240,7 @@ def log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: """ Log a single output to the current task or run. @@ -1229,6 +1261,15 @@ def log_output( dreadnode.log_output("other", 123) ~~~ + + Args: + name: The name of the output. + value: The value of the output. + label: An optional label for the output, useful for filtering in the UI. + to: The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. + attributes: Additional attributes to attach to the output. """ task = current_task_span.get() run = current_run_span.get() @@ -1239,7 +1280,7 @@ def log_output( "log_output() must be called within a run or a task", ) - target.log_output(name, value, label=label, **attributes) + target.log_output(name, value, label=label, attributes=attributes) ``` @@ -1461,7 +1502,7 @@ run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: Any, + attributes: AnyDict | None = None, ) -> RunSpan ``` @@ -1510,9 +1551,9 @@ with dreadnode.run("my_run"): `True` ) –Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to attach to the run span. @@ -1533,7 +1574,7 @@ def run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> RunSpan: """ Create a new run. @@ -1559,7 +1600,7 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. - **attributes: Additional attributes to attach to the run span. + attributes: Additional attributes to attach to the run span. Returns: A RunSpan object that can be used as a context manager. @@ -1594,7 +1635,7 @@ scorer( *, name: str | None = None, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]] ``` @@ -1629,9 +1670,9 @@ await my_task(2) `None` ) –A list of tags to attach to the scorer. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the scorer. @@ -1647,7 +1688,7 @@ def scorer( *, name: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]]: """ Make a scorer from a callable function. @@ -1671,7 +1712,7 @@ def scorer( Args: name: The name of the scorer. tags: A list of tags to attach to the scorer. - **attributes: A dictionary of attributes to attach to the scorer. + attributes: A dictionary of attributes to attach to the scorer. Returns: A new Scorer object. @@ -1735,7 +1776,7 @@ span( name: str, *, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> Span ``` @@ -1763,9 +1804,9 @@ with dreadnode.span("my_span") as span: `None` ) –A list of tags to attach to the span. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the span. @@ -1781,7 +1822,7 @@ def span( name: str, *, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> Span: """ Create a new OpenTelemety span. @@ -1800,7 +1841,7 @@ def span( Args: name: The name of the span. tags: A list of tags to attach to the span. - **attributes: A dictionary of attributes to attach to the span. + attributes: A dictionary of attributes to attach to the span. Returns: A Span object. @@ -1891,7 +1932,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskDecorator ``` @@ -1907,7 +1948,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> ScoredTaskDecorator[R] ``` @@ -1924,7 +1965,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskDecorator ``` @@ -1978,9 +2019,9 @@ await my_task(2) `None` ) –A list of tags to attach to the task span. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the task span. @@ -2001,7 +2042,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: """ Create a new task from a function. @@ -2024,7 +2065,7 @@ def task( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics for the task, such as success rate and run count. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A new Task object. @@ -2033,6 +2074,18 @@ def task( def make_task( func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R], ) -> Task[P, R]: + if isinstance(func, Task): + return func.with_( + name=name, + label=label, + log_inputs=log_inputs, + log_output=log_output, + log_execution_metrics=log_execution_metrics, + tags=tags, + attributes=attributes, + append=True, + ) + unwrapped = inspect.unwrap(func) if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction( @@ -2095,7 +2148,7 @@ task_span( *, label: str | None = None, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any] ``` @@ -2116,7 +2169,7 @@ Args: name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. -\*\*attributes: A dictionary of attributes to attach to the task span. +attributes: A dictionary of attributes to attach to the task span. **Returns:** @@ -2131,7 +2184,7 @@ def task_span( *, label: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any]: """ Create a task span without an explicit associated function. @@ -2149,21 +2202,20 @@ def task_span( name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A TaskSpan object. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Task spans must be created within a run") + run = current_run_span.get() + label = clean_str(label or name) - label = label or clean_str(name) return TaskSpan( name=name, label=label, attributes=attributes, tags=tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self._get_tracer(), ) ``` diff --git a/docs/sdk/task.mdx b/docs/sdk/task.mdx index b5610e2f..4de2c1fb 100644 --- a/docs/sdk/task.mdx +++ b/docs/sdk/task.mdx @@ -287,11 +287,18 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: The span associated with task execution. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Tasks must be executed within a run") + run = current_run_span.get() - log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs - log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output + log_inputs = ( + (run.autolog if run else False) + if isinstance(self.log_inputs, Inherited) + else self.log_inputs + ) + log_output = ( + (run.autolog if run else False) + if isinstance(self.log_output, Inherited) + else self.log_output + ) bound_args = self._bind_args(*args, **kwargs) @@ -313,16 +320,22 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: label=self.label, attributes=self.attributes, tags=self.tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if self.log_execution_metrics: - span.run.log_metric( - "count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True} + if run and self.log_execution_metrics: + run.log_metric( + "count", + 1, + prefix=f"{self.label}.exec", + mode="count", + attributes={"auto": True}, ) input_object_hashes: list[str] = [ - span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True) + span.log_input( + name, value, label=f"{self.label}.input.{name}", attributes={"auto": True} + ) for name, value in inputs_to_log.items() ] @@ -331,8 +344,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if inspect.isawaitable(output): output = await output except Exception: - if self.log_execution_metrics: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 0, prefix=f"{self.label}.exec", @@ -341,8 +354,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) raise - if self.log_execution_metrics: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 1, prefix=f"{self.label}.exec", @@ -351,23 +364,28 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) span.output = output - if log_output and ( - not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + if ( + run + and log_output + and ( + not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + ) ): output_object_hash = span.log_output( - "output", output, label=f"{self.label}.output", auto=True + "output", output, label=f"{self.label}.output", attributes={"auto": True} ) # Link the output to the inputs for input_object_hash in input_object_hashes: - span.run.link_objects(output_object_hash, input_object_hash) + run.link_objects(output_object_hash, input_object_hash) for scorer in self.scorers: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) # Trigger a run update whenever a task completes - run.push_update() + if run is not None: + run.push_update() return span ``` @@ -730,11 +748,14 @@ with_( name: str | None = None, tags: Sequence[str] | None = None, label: str | None = None, - log_inputs: Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: Sequence[str] + | bool + | Inherited + | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: Any, + attributes: AnyDict | None = None, ) -> Task[P, R] ``` @@ -763,12 +784,12 @@ Clone a task and modify its attributes. ) –The new label for the task. * **`log_inputs`** - (`Sequence[str] | bool | None`, default: + (`Sequence[str] | bool | Inherited | None`, default: `None` ) –Log all, or specific, incoming arguments to the function as inputs. * **`log_output`** - (`bool | None`, default: + (`bool | Inherited | None`, default: `None` ) –Log the result of the function as an output. @@ -782,9 +803,9 @@ Clone a task and modify its attributes. `False` ) –If True, appends the new scorers and tags to the existing ones. If False, replaces them. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to set or update in the task. @@ -802,11 +823,11 @@ def with_( name: str | None = None, tags: t.Sequence[str] | None = None, label: str | None = None, - log_inputs: t.Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: t.Sequence[str] | bool | Inherited | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> "Task[P, R]": """ Clone a task and modify its attributes. @@ -820,7 +841,7 @@ def with_( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics such as success rate and run count. append: If True, appends the new scorers and tags to the existing ones. If False, replaces them. - **attributes: Additional attributes to set or update in the task. + attributes: Additional attributes to set or update in the task. Returns: A new Task instance with the modified attributes. @@ -842,11 +863,11 @@ def with_( if append: task.scorers.extend(new_scorers) task.tags.extend(new_tags) - task.attributes.update(attributes) + task.attributes.update(attributes or {}) else: task.scorers = new_scorers task.tags = new_tags - task.attributes = attributes + task.attributes = attributes or {} return task ``` diff --git a/dreadnode/main.py b/dreadnode/main.py index a31dc88b..96d0600a 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -59,7 +59,6 @@ INHERITED, AnyDict, Inherited, - JsonDict, JsonValue, ) from dreadnode.util import clean_str, handle_internal_errors, logger @@ -445,7 +444,7 @@ def span( name: str, *, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> Span: """ Create a new OpenTelemety span. @@ -464,7 +463,7 @@ def span( Args: name: The name of the span. tags: A list of tags to attach to the span. - **attributes: A dictionary of attributes to attach to the span. + attributes: A dictionary of attributes to attach to the span. Returns: A Span object. @@ -528,7 +527,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: ... @t.overload @@ -542,7 +541,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> ScoredTaskDecorator[R]: ... def task( @@ -555,7 +554,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: """ Create a new task from a function. @@ -578,7 +577,7 @@ async def my_task(x: int) -> int: log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics for the task, such as success rate and run count. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A new Task object. @@ -587,6 +586,18 @@ async def my_task(x: int) -> int: def make_task( func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R], ) -> Task[P, R]: + if isinstance(func, Task): + return func.with_( + name=name, + label=label, + log_inputs=log_inputs, + log_output=log_output, + log_execution_metrics=log_execution_metrics, + tags=tags, + attributes=attributes, + append=True, + ) + unwrapped = inspect.unwrap(func) if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction( @@ -643,7 +654,7 @@ def task_span( *, label: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any]: """ Create a task span without an explicit associated function. @@ -661,21 +672,20 @@ def task_span( name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A TaskSpan object. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Task spans must be created within a run") + run = current_run_span.get() + label = clean_str(label or name) - label = label or clean_str(name) return TaskSpan( name=name, label=label, attributes=attributes, tags=tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self._get_tracer(), ) @@ -684,7 +694,7 @@ def scorer( *, name: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]]: """ Make a scorer from a callable function. @@ -708,7 +718,7 @@ async def my_task(x: int) -> int: Args: name: The name of the scorer. tags: A list of tags to attach to the scorer. - **attributes: A dictionary of attributes to attach to the scorer. + attributes: A dictionary of attributes to attach to the scorer. Returns: A new Scorer object. @@ -733,7 +743,7 @@ def run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> RunSpan: """ Create a new run. @@ -759,7 +769,7 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. - **attributes: Additional attributes to attach to the run span. + attributes: Additional attributes to attach to the run span. Returns: A RunSpan object that can be used as a context manager. @@ -937,7 +947,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -1028,7 +1038,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -1102,7 +1112,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1142,7 +1152,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1181,7 +1191,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1314,7 +1324,7 @@ def log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> None: """ Log a single input to the current task or run. @@ -1342,13 +1352,13 @@ async def my_task(x: int) -> int: if target is None: raise RuntimeError("log_inputs() must be called within a run") - target.log_input(name, value, label=label, **attributes) + target.log_input(name, value, label=label, attributes=attributes) @handle_internal_errors() def log_inputs( self, to: ToObject = "task-or-run", - **inputs: JsonValue, + **inputs: t.Any, ) -> None: """ Log multiple inputs to the current task or run. @@ -1366,7 +1376,7 @@ def log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: """ Log a single output to the current task or run. @@ -1387,6 +1397,15 @@ async def my_task(x: int) -> int: dreadnode.log_output("other", 123) ``` + + Args: + name: The name of the output. + value: The value of the output. + label: An optional label for the output, useful for filtering in the UI. + to: The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. + attributes: Additional attributes to attach to the output. """ task = current_task_span.get() run = current_run_span.get() @@ -1397,7 +1416,7 @@ async def my_task(x: int) -> int: "log_output() must be called within a run or a task", ) - target.log_output(name, value, label=label, **attributes) + target.log_output(name, value, label=label, attributes=attributes) @handle_internal_errors() def log_outputs( @@ -1414,7 +1433,12 @@ def log_outputs( self.log_output(name, value, to=to) @handle_internal_errors() - def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> None: + def link_objects( + self, + origin: t.Any, + link: t.Any, + attributes: AnyDict | None = None, + ) -> None: """ Associate two runtime objects with each other. @@ -1434,14 +1458,14 @@ def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> N Args: origin: The origin object to link from. link: The linked object to link to. - **attributes: Additional attributes to attach to the link. + attributes: Additional attributes to attach to the link. """ if (run := current_run_span.get()) is None: raise RuntimeError("link() must be called within a run") origin_hash = run.log_object(origin) link_hash = run.log_object(link) - run.link_objects(origin_hash, link_hash, **attributes) + run.link_objects(origin_hash, link_hash, attributes=attributes) DEFAULT_INSTANCE = Dreadnode() diff --git a/dreadnode/object.py b/dreadnode/object.py index 26e7aab9..28dbe589 100644 --- a/dreadnode/object.py +++ b/dreadnode/object.py @@ -1,7 +1,7 @@ import typing as t from dataclasses import dataclass -from dreadnode.types import JsonDict +from dreadnode.types import AnyDict @dataclass @@ -9,7 +9,7 @@ class ObjectRef: name: str label: str hash: str - attributes: JsonDict + attributes: AnyDict | None @dataclass diff --git a/dreadnode/task.py b/dreadnode/task.py index 7c381e3e..37bd9b0e 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -9,8 +9,8 @@ from dreadnode.metric import Scorer, ScorerCallable from dreadnode.serialization import seems_useful_to_serialize -from dreadnode.tracing.span import Span, TaskSpan, current_run_span -from dreadnode.types import INHERITED, Inherited +from dreadnode.tracing.span import TaskSpan, current_run_span +from dreadnode.types import INHERITED, AnyDict, Inherited P = t.ParamSpec("P") R = t.TypeVar("R") @@ -180,11 +180,11 @@ def with_( name: str | None = None, tags: t.Sequence[str] | None = None, label: str | None = None, - log_inputs: t.Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: t.Sequence[str] | bool | Inherited | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> "Task[P, R]": """ Clone a task and modify its attributes. @@ -198,7 +198,7 @@ def with_( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics such as success rate and run count. append: If True, appends the new scorers and tags to the existing ones. If False, replaces them. - **attributes: Additional attributes to set or update in the task. + attributes: Additional attributes to set or update in the task. Returns: A new Task instance with the modified attributes. @@ -220,11 +220,11 @@ def with_( if append: task.scorers.extend(new_scorers) task.tags.extend(new_tags) - task.attributes.update(attributes) + task.attributes.update(attributes or {}) else: task.scorers = new_scorers task.tags = new_tags - task.attributes = attributes + task.attributes = attributes or {} return task @@ -240,11 +240,18 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: The span associated with task execution. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Tasks must be executed within a run") + run = current_run_span.get() - log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs - log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output + log_inputs = ( + (run.autolog if run else False) + if isinstance(self.log_inputs, Inherited) + else self.log_inputs + ) + log_output = ( + (run.autolog if run else False) + if isinstance(self.log_output, Inherited) + else self.log_output + ) bound_args = self._bind_args(*args, **kwargs) @@ -266,16 +273,22 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: label=self.label, attributes=self.attributes, tags=self.tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if self.log_execution_metrics: - span.run.log_metric( - "count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True} + if run and self.log_execution_metrics: + run.log_metric( + "count", + 1, + prefix=f"{self.label}.exec", + mode="count", + attributes={"auto": True}, ) input_object_hashes: list[str] = [ - span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True) + span.log_input( + name, value, label=f"{self.label}.input.{name}", attributes={"auto": True} + ) for name, value in inputs_to_log.items() ] @@ -284,8 +297,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if inspect.isawaitable(output): output = await output except Exception: - if self.log_execution_metrics: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 0, prefix=f"{self.label}.exec", @@ -294,8 +307,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) raise - if self.log_execution_metrics: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 1, prefix=f"{self.label}.exec", @@ -304,40 +317,32 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) span.output = output - if log_output and ( - not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + if ( + run + and log_output + and ( + not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + ) ): output_object_hash = span.log_output( - "output", output, label=f"{self.label}.output", auto=True + "output", output, label=f"{self.label}.output", attributes={"auto": True} ) # Link the output to the inputs for input_object_hash in input_object_hashes: - span.run.link_objects(output_object_hash, input_object_hash) + run.link_objects(output_object_hash, input_object_hash) for scorer in self.scorers: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) # Trigger a run update whenever a task completes - run.push_update() + if run is not None: + run.push_update() return span async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - if not current_run_span.get(): - with Span( - self.name, - self.attributes, - self.tracer, - label=self.label, - tags=self.tags, - ): - result = self.func(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result - span = await self.run(*args, **kwargs) return span.output diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 5b9b04b2..54898b16 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -36,7 +36,7 @@ from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize -from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset +from dreadnode.types import UNSET, AnyDict, JsonDict, Unset from dreadnode.util import clean_str from dreadnode.version import VERSION @@ -101,9 +101,9 @@ class Span(ReadableSpan): def __init__( self, name: str, - attributes: AnyDict, tracer: Tracer, *, + attributes: AnyDict | None = None, label: str | None = None, type: SpanType = "span", tags: t.Sequence[str] | None = None, @@ -120,7 +120,7 @@ def __init__( SPAN_ATTRIBUTE_TYPE: type, SPAN_ATTRIBUTE_LABEL: self._label, SPAN_ATTRIBUTE_TAGS_: self.tags, - **attributes, + **(attributes or {}), } self._tracer = tracer @@ -264,6 +264,25 @@ def log_event( attributes=prepare_otlp_attributes(attributes or {}), ) + def set_exception( + self, + exception: BaseException, + *, + attributes: AnyDict | None = None, + status: Status | None = None, + ) -> None: + if self._span is None: + raise ValueError("Span is not active") + + if status is None: + status = Status(StatusCode.ERROR, str(exception)) + + self._span.set_status(status) + self._span.record_exception( + exception, + attributes=prepare_otlp_attributes(attributes or {}), + ) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(name='{self._span_name}', id={self.span_id}," @@ -317,7 +336,7 @@ def __init__( large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS) attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs - super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update") + super().__init__(f"run.{run_id}.update", tracer, type="run_update", attributes=attributes) def __repr__(self) -> str: status = "active" if self.is_recording else "inactive" @@ -335,11 +354,11 @@ def __init__( self, name: str, project: str, - attributes: AnyDict, tracer: Tracer, file_system: AbstractFileSystem, prefix_path: str, *, + attributes: AnyDict | None = None, params: AnyDict | None = None, metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, @@ -386,9 +405,9 @@ def __init__( attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()), SPAN_ATTRIBUTE_PROJECT: project, - **attributes, + **(attributes or {}), } - super().__init__(name, attributes, tracer, type=type, tags=tags) + super().__init__(name, tracer, attributes=attributes, type=type, tags=tags) @classmethod def from_context( @@ -545,7 +564,7 @@ def log_object( *, label: str | None = None, event_name: str = EVENT_NAME_OBJECT, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: serialized = serialize(value) data_hash = serialized.data_hash @@ -571,7 +590,7 @@ def log_object( # Build event attributes, use composite hash in events event_attributes = { - **attributes, + **(attributes or {}), EVENT_ATTRIBUTE_OBJECT_HASH: composite_hash, EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: trace_api.format_span_id( trace_api.get_current_span().get_span_context().span_id, @@ -638,12 +657,12 @@ def link_objects( self, object_hash: str, link_hash: str, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: self.log_event( name=EVENT_NAME_OBJECT_LINK, attributes={ - **attributes, + **(attributes or {}), EVENT_ATTRIBUTE_OBJECT_HASH: object_hash, EVENT_ATTRIBUTE_LINK_HASH: link_hash, EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: ( @@ -679,9 +698,9 @@ def log_input( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: - label = label or clean_str(name) + label = clean_str(label or name) hash_ = self.log_object( value, label=label, @@ -757,7 +776,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) @@ -791,9 +813,9 @@ def log_output( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: - label = label or clean_str(name) + label = clean_str(label or name) hash_ = self.log_object( value, label=label, @@ -827,10 +849,10 @@ class TaskSpan(Span, t.Generic[R]): def __init__( self, name: str, - attributes: AnyDict, run_id: str, tracer: Tracer, *, + attributes: AnyDict | None = None, label: str | None = None, metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, @@ -851,9 +873,9 @@ def __init__( SPAN_ATTRIBUTE_INPUTS: self._inputs, SPAN_ATTRIBUTE_METRICS: self._metrics, SPAN_ATTRIBUTE_OUTPUTS: self._outputs, - **attributes, + **(attributes or {}), } - super().__init__(name, attributes, tracer, type="task", label=label, tags=tags) + super().__init__(name, tracer, type="task", attributes=attributes, label=label, tags=tags) def __enter__(self) -> te.Self: self._run = current_run_span.get() @@ -921,7 +943,9 @@ def run(self) -> RunSpan: @property def outputs(self) -> AnyDict: - return {ref.name: self.run.get_object(ref.hash) for ref in self._outputs} + if self._run is None: + return {} + return {ref.name: self._run.get_object(ref.hash) for ref in self._outputs} @property def output(self) -> R: @@ -939,10 +963,16 @@ def log_output( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: - label = label or clean_str(name) - hash_ = self.run.log_object( + label = clean_str(label or name) + + if self._run is None: + serialized = serialize(value) + self.set_attribute(label, serialized.data, schema=False) + return serialized.data_hash + + hash_ = self._run.log_object( value, label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, @@ -952,7 +982,9 @@ def log_output( @property def inputs(self) -> AnyDict: - return {ref.name: self.run.get_object(ref.hash) for ref in self._inputs} + if self._run is None: + return {} + return {ref.name: self._run.get_object(ref.hash) for ref in self._inputs} def log_input( self, @@ -960,10 +992,16 @@ def log_input( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: - label = label or clean_str(name) - hash_ = self.run.log_object( + label = clean_str(label or name) + + if self._run is None: + serialized = serialize(value) + self.set_attribute(label, serialized.data, schema=False) + return serialized.data_hash + + hash_ = self._run.log_object( value, label=label, event_name=EVENT_NAME_OBJECT_INPUT, @@ -1013,7 +1051,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) diff --git a/dreadnode/util.py b/dreadnode/util.py index 9380c391..89262d23 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -32,7 +32,7 @@ def clean_str(s: str) -> str: """ Clean a string by replacing all non-alphanumeric characters (except `/` and `@`) with underscores. """ - return re.sub(r"[^\w/@]+", "_", s.lower()) + return re.sub(r"[^\w/@]+", "_", s.lower()).strip("_") def safe_repr(obj: t.Any) -> str: