From baad34dd73a0f5eb49f0a95afffd78990d30e0ee Mon Sep 17 00:00:00 2001 From: monoxgas Date: Sat, 12 Jul 2025 16:59:16 -0600 Subject: [PATCH 1/4] Add the ability for tasks to be executed outside of runs --- docs/sdk/main.mdx | 15 +++++--- docs/sdk/task.mdx | 49 +++++++++++++++-------- dreadnode/main.py | 19 +++++---- dreadnode/task.py | 65 +++++++++++++++++-------------- dreadnode/tracing/span.py | 81 +++++++++++++++++++++++++++------------ 5 files changed, 148 insertions(+), 81 deletions(-) diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index d382da95..538be1f9 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -248,8 +248,12 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + self.server = ( + server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + ) + self.token = ( + token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -2192,17 +2196,16 @@ def task_span( Returns: A TaskSpan object. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Task spans must be created within a run") - + run = current_run_span.get() label = label or clean_str(name) + return TaskSpan( name=name, label=label, attributes=attributes, params=params, tags=tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self._get_tracer(), ) ``` diff --git a/docs/sdk/task.mdx b/docs/sdk/task.mdx index 6b63cd29..60752d23 100644 --- a/docs/sdk/task.mdx +++ b/docs/sdk/task.mdx @@ -288,11 +288,14 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: The span associated with task execution. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Tasks must be executed within a run") + run = current_run_span.get() - log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs - log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output + log_inputs = ( + run.autolog if run and isinstance(self.log_inputs, Inherited) else self.log_inputs + ) + log_output = ( + run.autolog if run and isinstance(self.log_output, Inherited) else self.log_output + ) bound_args = self._bind_args(*args, **kwargs) @@ -311,18 +314,27 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: else {} ) + # If log_inputs is inherited, filter out items that don't seem useful + # to serialize like `None` or repr fallbacks. + if isinstance(self.log_inputs, Inherited): + inputs_to_log = {k: v for k, v in inputs_to_log.items() if seems_useful_to_serialize(v)} + with TaskSpan[R]( name=self.name, label=self.label, attributes=self.attributes, params=params_to_log, tags=self.tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if run.autolog: - span.run.log_metric( - "count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True} + if self.log_execution_metrics: + run.log_metric( + "count", + 1, + prefix=f"{self.label}.exec", + mode="count", + attributes={"auto": True}, ) for name, value in params_to_log.items(): @@ -338,8 +350,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if inspect.isawaitable(output): output = await output except Exception: - if run.autolog: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 0, prefix=f"{self.label}.exec", @@ -348,8 +360,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) raise - if run.autolog: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 1, prefix=f"{self.label}.exec", @@ -358,21 +370,28 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) span.output = output - if log_output: + if ( + run + and log_output + and ( + not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + ) + ): output_object_hash = span.log_output( "output", output, label=f"{self.label}.output", auto=True ) # Link the output to the inputs for input_object_hash in input_object_hashes: - span.run.link_objects(output_object_hash, input_object_hash) + run.link_objects(output_object_hash, input_object_hash) for scorer in self.scorers: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) # Trigger a run update whenever a task completes - run.push_update() + if run is not None: + run.push_update() return span ``` diff --git a/dreadnode/main.py b/dreadnode/main.py index 8575f508..94acb9b8 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -207,7 +207,9 @@ def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: return str(endpoint) # If nothing works, return original and let it fail with a helpful error - raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") + raise RuntimeError( + f"Failed to connect to the Dreadnode Artifact storage at {endpoint}." + ) @staticmethod def _test_connection(endpoint: str) -> bool: @@ -265,8 +267,12 @@ def configure( self._initialized = False - self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + self.server = ( + server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + ) + self.token = ( + token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) + ) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -668,17 +674,16 @@ def task_span( Returns: A TaskSpan object. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Task spans must be created within a run") - + run = current_run_span.get() label = label or clean_str(name) + return TaskSpan( name=name, label=label, attributes=attributes, params=params, tags=tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self._get_tracer(), ) diff --git a/dreadnode/task.py b/dreadnode/task.py index 8e23c908..29ba345e 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -8,7 +8,8 @@ from opentelemetry.trace import Tracer from dreadnode.metric import Scorer, ScorerCallable -from dreadnode.tracing.span import Span, TaskSpan, current_run_span +from dreadnode.serialization import seems_useful_to_serialize +from dreadnode.tracing.span import TaskSpan, current_run_span from dreadnode.types import INHERITED, Inherited P = t.ParamSpec("P") @@ -237,11 +238,14 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: The span associated with task execution. """ - if (run := current_run_span.get()) is None: - raise RuntimeError("Tasks must be executed within a run") + run = current_run_span.get() - log_inputs = run.autolog if isinstance(self.log_inputs, Inherited) else self.log_inputs - log_output = run.autolog if isinstance(self.log_output, Inherited) else self.log_output + log_inputs = ( + run.autolog if run and isinstance(self.log_inputs, Inherited) else self.log_inputs + ) + log_output = ( + run.autolog if run and isinstance(self.log_output, Inherited) else self.log_output + ) bound_args = self._bind_args(*args, **kwargs) @@ -260,18 +264,27 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: else {} ) + # If log_inputs is inherited, filter out items that don't seem useful + # to serialize like `None` or repr fallbacks. + if isinstance(self.log_inputs, Inherited): + inputs_to_log = {k: v for k, v in inputs_to_log.items() if seems_useful_to_serialize(v)} + with TaskSpan[R]( name=self.name, label=self.label, attributes=self.attributes, params=params_to_log, tags=self.tags, - run_id=run.run_id, + run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if run.autolog: - span.run.log_metric( - "count", 1, prefix=f"{self.label}.exec", mode="count", attributes={"auto": True} + if self.log_execution_metrics: + run.log_metric( + "count", + 1, + prefix=f"{self.label}.exec", + mode="count", + attributes={"auto": True}, ) for name, value in params_to_log.items(): @@ -287,8 +300,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: if inspect.isawaitable(output): output = await output except Exception: - if run.autolog: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 0, prefix=f"{self.label}.exec", @@ -297,8 +310,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) raise - if run.autolog: - span.run.log_metric( + if run and self.log_execution_metrics: + run.log_metric( "success_rate", 1, prefix=f"{self.label}.exec", @@ -307,38 +320,32 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) span.output = output - if log_output: + if ( + run + and log_output + and ( + not isinstance(self.log_inputs, Inherited) or seems_useful_to_serialize(output) + ) + ): output_object_hash = span.log_output( "output", output, label=f"{self.label}.output", auto=True ) # Link the output to the inputs for input_object_hash in input_object_hashes: - span.run.link_objects(output_object_hash, input_object_hash) + run.link_objects(output_object_hash, input_object_hash) for scorer in self.scorers: metric = await scorer(output) span.log_metric(scorer.name, metric, origin=output) # Trigger a run update whenever a task completes - run.push_update() + if run is not None: + run.push_update() return span async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - if not current_run_span.get(): - with Span( - self.name, - self.attributes, - self.tracer, - label=self.label, - tags=self.tags, - ): - result = self.func(*args, **kwargs) - if inspect.isawaitable(result): - result = await result - return result - span = await self.run(*args, **kwargs) return span.output diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 37770573..57b5ffde 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -198,7 +198,9 @@ def set_attribute( self._added_attributes = True if schema and raw is False: self._schema[key] = create_json_schema(value, set()) - otel_value = self._pre_attributes[key] = value if raw else prepare_otlp_attribute(value) + otel_value = self._pre_attributes[key] = ( + value if raw else prepare_otlp_attribute(value) + ) if self._span is not None: self._span.set_attribute(key, otel_value) self._pre_attributes[key] = otel_value @@ -258,7 +260,11 @@ def __init__( **({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}), **({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}), **({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}), - **({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}), + **( + {SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} + if object_schemas + else {} + ), } # Mark objects and schemas as large attributes if present @@ -428,7 +434,9 @@ def push_update(self, *, force: bool = False) -> None: return current_time = time.time() - force_update = force or (current_time - self._last_update_time >= self._update_frequency) + force_update = force or ( + current_time - self._last_update_time >= self._update_frequency + ) should_update = force_update and ( self._pending_params or self._pending_inputs @@ -450,7 +458,9 @@ def push_update(self, *, force: bool = False) -> None: inputs=self._pending_inputs if self._pending_inputs else None, outputs=self._pending_outputs if self._pending_outputs else None, objects=self._pending_objects if self._pending_objects else None, - object_schemas=self._pending_object_schemas if self._pending_object_schemas else None, + object_schemas=self._pending_object_schemas + if self._pending_object_schemas + else None, ): pass @@ -531,7 +541,9 @@ def _store_file_by_hash(self, data: bytes, full_path: str) -> str: return str(self._file_system.unstrip_protocol(full_path)) - def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Object: + def _create_object_by_hash( + self, serialized: Serialized, object_hash: str + ) -> Object: """Create an ObjectVal or ObjectUri depending on size with a specific hash.""" data = serialized.data data_bytes = serialized.data_bytes @@ -685,7 +697,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) @@ -752,7 +767,9 @@ def __init__( self._output: R | Unset = UNSET # For the python output - self._context_token: Token[TaskSpan[t.Any] | None] | None = None # contextvars context + self._context_token: Token[TaskSpan[t.Any] | None] | None = ( + None # contextvars context + ) attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id), @@ -770,9 +787,6 @@ def __enter__(self) -> te.Self: self.set_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, self._parent_task.span_id) self._run = current_run_span.get() - if self._run is None: - raise RuntimeError("You cannot start a task span without a run") - self._context_token = current_task_span.set(self) return super().__enter__() @@ -798,15 +812,11 @@ def run_id(self) -> str: def parent_task_id(self) -> str: return str(self.get_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, "")) - @property - def run(self) -> RunSpan: - if self._run is None: - raise ValueError("Task span is not in an active run") - return self._run - @property def outputs(self) -> AnyDict: - return {ref.name: self.run.get_object(ref.hash) for ref in self._outputs} + if self._run is None: + return {} + return {ref.name: self._run.get_object(ref.hash) for ref in self._outputs} @property def output(self) -> R: @@ -827,12 +837,20 @@ def log_output( **attributes: JsonValue, ) -> str: label = label or clean_str(name) - hash_ = self.run.log_object( + + if self._run is None: + serialized = serialize(value) + self.set_attribute(label, serialized.data, schema=False) + return serialized.data_hash + + hash_ = self._run.log_object( value, label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, ) - self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + self._outputs.append( + ObjectRef(name, label=label, hash=hash_, attributes=attributes) + ) return hash_ @property @@ -847,7 +865,9 @@ def log_params(self, **params: t.Any) -> None: @property def inputs(self) -> AnyDict: - return {ref.name: self.run.get_object(ref.hash) for ref in self._inputs} + if self._run is None: + return {} + return {ref.name: self._run.get_object(ref.hash) for ref in self._inputs} def log_input( self, @@ -858,12 +878,20 @@ def log_input( **attributes: JsonValue, ) -> str: label = label or clean_str(name) - hash_ = self.run.log_object( + + if self._run is None: + serialized = serialize(value) + self.set_attribute(label, serialized.data, schema=False) + return serialized.data_hash + + hash_ = self._run.log_object( value, label=label, event_name=EVENT_NAME_OBJECT_INPUT, ) - self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) + self._inputs.append( + ObjectRef(name, label=label, hash=hash_, attributes=attributes) + ) return hash_ @property @@ -908,7 +936,10 @@ def log_metric( value if isinstance(value, Metric) else Metric( - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} + float(value), + step, + timestamp or datetime.now(timezone.utc), + attributes or {}, ) ) @@ -922,7 +953,9 @@ def log_metric( # this task-metric was logged here. if (run := current_run_span.get()) is not None: - metric = run.log_metric(key, metric, prefix=self._label, origin=origin, mode=mode) + metric = run.log_metric( + key, metric, prefix=self._label, origin=origin, mode=mode + ) self._metrics.setdefault(key, []).append(metric) From 58a848481b2d2706fe5a02fa77b647fa0b8247d2 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 16 Jul 2025 17:37:13 -0600 Subject: [PATCH 2/4] Finalizing updates and various arg cleanup --- docs/sdk/main.mdx | 159 +++++++++++++++++++++++--------------- dreadnode/main.py | 79 ++++++++++--------- dreadnode/tracing/span.py | 68 ++++++---------- dreadnode/util.py | 2 +- 4 files changed, 166 insertions(+), 142 deletions(-) diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 26371ab0..5579dee3 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -248,12 +248,8 @@ def configure( self._initialized = False - self.server = ( - server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - ) - self.token = ( - token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) - ) + self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -496,7 +492,9 @@ def initialize(self) -> None: ```python link_objects( - origin: Any, link: Any, **attributes: JsonValue + origin: Any, + link: Any, + attributes: AnyDict | None = None, ) -> None ``` @@ -524,16 +522,21 @@ with dreadnode.run("my_run"): * **`link`** (`Any`) –The linked object to link to. -* **`**attributes`** - (`JsonValue`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to attach to the link. ```python @handle_internal_errors() -def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> None: +def link_objects( + self, + origin: t.Any, + link: t.Any, + attributes: AnyDict | None = None, +) -> None: """ Associate two runtime objects with each other. @@ -553,14 +556,14 @@ def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> N Args: origin: The origin object to link from. link: The linked object to link to. - **attributes: Additional attributes to attach to the link. + attributes: Additional attributes to attach to the link. """ if (run := current_run_span.get()) is None: raise RuntimeError("link() must be called within a run") origin_hash = run.log_object(origin) link_hash = run.log_object(link) - run.link_objects(origin_hash, link_hash, **attributes) + run.link_objects(origin_hash, link_hash, attributes=attributes) ``` @@ -666,7 +669,7 @@ log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: Any, + attributes: AnyDict | None = None, ) -> None ``` @@ -700,7 +703,7 @@ def log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> None: """ Log a single input to the current task or run. @@ -728,7 +731,7 @@ def log_input( if target is None: raise RuntimeError("log_inputs() must be called within a run") - target.log_input(name, value, label=label, **attributes) + target.log_input(name, value, label=label, attributes=attributes) ``` @@ -738,7 +741,7 @@ def log_input( ```python log_inputs( - to: ToObject = "task-or-run", **inputs: JsonValue + to: ToObject = "task-or-run", **inputs: Any ) -> None ``` @@ -752,7 +755,7 @@ See `log_input()` for more details. def log_inputs( self, to: ToObject = "task-or-run", - **inputs: JsonValue, + **inputs: t.Any, ) -> None: """ Log multiple inputs to the current task or run. @@ -777,7 +780,7 @@ log_metric( origin: Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric ``` @@ -802,7 +805,7 @@ log_metric( origin: Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric ``` @@ -867,7 +870,7 @@ with dreadnode.run("my_run"): - sum: the cumulative sum of all reported values for this metric - count: increment every time this metric is logged - disregard value * **`attributes`** - (`JsonDict | None`, default: + (`AnyDict | None`, default: `None` ) –A dictionary of additional attributes to attach to the metric. @@ -896,7 +899,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -975,7 +978,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -987,7 +990,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -999,7 +1002,7 @@ log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric] ``` @@ -1054,7 +1057,7 @@ dreadnode.log_metrics( ) –Default aggregation mode for metrics if not supplied. * **`attributes`** - (`JsonDict | None`, default: + (`AnyDict | None`, default: `None` ) –Default attributes for metrics if not supplied. @@ -1081,7 +1084,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1177,7 +1180,7 @@ log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None ``` @@ -1201,6 +1204,31 @@ with dreadnode.run("my_run"): dreadnode.log_output("other", 123) ``` +**Parameters:** + +* **`name`** + (`str`) + –The name of the output. +* **`value`** + (`Any`) + –The value of the output. +* **`label`** + (`str | None`, default: + `None` + ) + –An optional label for the output, useful for filtering in the UI. +* **`to`** + (`ToObject`, default: + `'task-or-run'` + ) + –The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. +* **`attributes`** + (`AnyDict | None`, default: + `None` + ) + –Additional attributes to attach to the output. ```python @@ -1212,7 +1240,7 @@ def log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: """ Log a single output to the current task or run. @@ -1233,6 +1261,15 @@ def log_output( dreadnode.log_output("other", 123) ~~~ + + Args: + name: The name of the output. + value: The value of the output. + label: An optional label for the output, useful for filtering in the UI. + to: The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. + attributes: Additional attributes to attach to the output. """ task = current_task_span.get() run = current_run_span.get() @@ -1243,7 +1280,7 @@ def log_output( "log_output() must be called within a run or a task", ) - target.log_output(name, value, label=label, **attributes) + target.log_output(name, value, label=label, attributes=attributes) ``` @@ -1465,7 +1502,7 @@ run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: Any, + attributes: AnyDict | None = None, ) -> RunSpan ``` @@ -1514,9 +1551,9 @@ with dreadnode.run("my_run"): `True` ) –Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to attach to the run span. @@ -1537,7 +1574,7 @@ def run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> RunSpan: """ Create a new run. @@ -1563,7 +1600,7 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. - **attributes: Additional attributes to attach to the run span. + attributes: Additional attributes to attach to the run span. Returns: A RunSpan object that can be used as a context manager. @@ -1598,7 +1635,7 @@ scorer( *, name: str | None = None, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]] ``` @@ -1633,9 +1670,9 @@ await my_task(2) `None` ) –A list of tags to attach to the scorer. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the scorer. @@ -1651,7 +1688,7 @@ def scorer( *, name: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]]: """ Make a scorer from a callable function. @@ -1675,7 +1712,7 @@ def scorer( Args: name: The name of the scorer. tags: A list of tags to attach to the scorer. - **attributes: A dictionary of attributes to attach to the scorer. + attributes: A dictionary of attributes to attach to the scorer. Returns: A new Scorer object. @@ -1739,7 +1776,7 @@ span( name: str, *, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> Span ``` @@ -1767,9 +1804,9 @@ with dreadnode.span("my_span") as span: `None` ) –A list of tags to attach to the span. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the span. @@ -1785,7 +1822,7 @@ def span( name: str, *, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> Span: """ Create a new OpenTelemety span. @@ -1804,7 +1841,7 @@ def span( Args: name: The name of the span. tags: A list of tags to attach to the span. - **attributes: A dictionary of attributes to attach to the span. + attributes: A dictionary of attributes to attach to the span. Returns: A Span object. @@ -1895,7 +1932,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskDecorator ``` @@ -1911,7 +1948,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> ScoredTaskDecorator[R] ``` @@ -1928,7 +1965,7 @@ task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskDecorator ``` @@ -1982,9 +2019,9 @@ await my_task(2) `None` ) –A list of tags to attach to the task span. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –A dictionary of attributes to attach to the task span. @@ -2005,7 +2042,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: """ Create a new task from a function. @@ -2028,7 +2065,7 @@ def task( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics for the task, such as success rate and run count. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A new Task object. @@ -2099,7 +2136,7 @@ task_span( *, label: str | None = None, tags: Sequence[str] | None = None, - **attributes: Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any] ``` @@ -2120,7 +2157,7 @@ Args: name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. -\*\*attributes: A dictionary of attributes to attach to the task span. +attributes: A dictionary of attributes to attach to the task span. **Returns:** @@ -2135,7 +2172,7 @@ def task_span( *, label: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any]: """ Create a task span without an explicit associated function. @@ -2153,13 +2190,13 @@ def task_span( name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A TaskSpan object. """ run = current_run_span.get() - label = label or clean_str(name) + label = clean_str(label or name) return TaskSpan( name=name, diff --git a/dreadnode/main.py b/dreadnode/main.py index 5d5e428b..14323001 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -59,7 +59,6 @@ INHERITED, AnyDict, Inherited, - JsonDict, JsonValue, ) from dreadnode.util import clean_str, handle_internal_errors, logger @@ -207,9 +206,7 @@ def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: return str(endpoint) # If nothing works, return original and let it fail with a helpful error - raise RuntimeError( - f"Failed to connect to the Dreadnode Artifact storage at {endpoint}." - ) + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") @staticmethod def _test_connection(endpoint: str) -> bool: @@ -267,12 +264,8 @@ def configure( self._initialized = False - self.server = ( - server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) - ) - self.token = ( - token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) - ) + self.server = server or os.environ.get(ENV_SERVER_URL) or os.environ.get(ENV_SERVER) + self.token = token or os.environ.get(ENV_API_TOKEN) or os.environ.get(ENV_API_KEY) if local_dir is False and ENV_LOCAL_DIR in os.environ: env_local_dir = os.environ.get(ENV_LOCAL_DIR) @@ -451,7 +444,7 @@ def span( name: str, *, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> Span: """ Create a new OpenTelemety span. @@ -470,7 +463,7 @@ def span( Args: name: The name of the span. tags: A list of tags to attach to the span. - **attributes: A dictionary of attributes to attach to the span. + attributes: A dictionary of attributes to attach to the span. Returns: A Span object. @@ -534,7 +527,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: ... @t.overload @@ -548,7 +541,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> ScoredTaskDecorator[R]: ... def task( @@ -561,7 +554,7 @@ def task( log_output: bool | Inherited = INHERITED, log_execution_metrics: bool = False, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskDecorator: """ Create a new task from a function. @@ -584,7 +577,7 @@ async def my_task(x: int) -> int: log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics for the task, such as success rate and run count. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A new Task object. @@ -649,7 +642,7 @@ def task_span( *, label: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> TaskSpan[t.Any]: """ Create a task span without an explicit associated function. @@ -667,13 +660,13 @@ def task_span( name: The name of the task. label: The label of the task - useful for filtering in the UI. tags: A list of tags to attach to the task span. - **attributes: A dictionary of attributes to attach to the task span. + attributes: A dictionary of attributes to attach to the task span. Returns: A TaskSpan object. """ run = current_run_span.get() - label = label or clean_str(name) + label = clean_str(label or name) return TaskSpan( name=name, @@ -689,7 +682,7 @@ def scorer( *, name: str | None = None, tags: t.Sequence[str] | None = None, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> t.Callable[[ScorerCallable[T]], Scorer[T]]: """ Make a scorer from a callable function. @@ -713,7 +706,7 @@ async def my_task(x: int) -> int: Args: name: The name of the scorer. tags: A list of tags to attach to the scorer. - **attributes: A dictionary of attributes to attach to the scorer. + attributes: A dictionary of attributes to attach to the scorer. Returns: A new Scorer object. @@ -738,7 +731,7 @@ def run( params: AnyDict | None = None, project: str | None = None, autolog: bool = True, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> RunSpan: """ Create a new run. @@ -764,7 +757,7 @@ def run( the project passed to `configure()` will be used, or the run will be associated with a default project. autolog: Whether to automatically log task inputs, outputs, and execution metrics if otherwise unspecified. - **attributes: Additional attributes to attach to the run span. + attributes: Additional attributes to attach to the run span. Returns: A RunSpan object that can be used as a context manager. @@ -942,7 +935,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -1033,7 +1026,7 @@ def log_metric( origin: t.Any | None = None, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> Metric: """ @@ -1107,7 +1100,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1147,7 +1140,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1186,7 +1179,7 @@ def log_metrics( step: int = 0, timestamp: datetime | None = None, mode: MetricAggMode | None = None, - attributes: JsonDict | None = None, + attributes: AnyDict | None = None, to: ToObject = "task-or-run", ) -> list[Metric]: """ @@ -1319,7 +1312,7 @@ def log_input( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> None: """ Log a single input to the current task or run. @@ -1347,13 +1340,13 @@ async def my_task(x: int) -> int: if target is None: raise RuntimeError("log_inputs() must be called within a run") - target.log_input(name, value, label=label, **attributes) + target.log_input(name, value, label=label, attributes=attributes) @handle_internal_errors() def log_inputs( self, to: ToObject = "task-or-run", - **inputs: JsonValue, + **inputs: t.Any, ) -> None: """ Log multiple inputs to the current task or run. @@ -1371,7 +1364,7 @@ def log_output( *, label: str | None = None, to: ToObject = "task-or-run", - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: """ Log a single output to the current task or run. @@ -1392,6 +1385,15 @@ async def my_task(x: int) -> int: dreadnode.log_output("other", 123) ``` + + Args: + name: The name of the output. + value: The value of the output. + label: An optional label for the output, useful for filtering in the UI. + to: The target object to log the output to. Can be "task-or-run" or "run". + Defaults to "task-or-run". If "task-or-run", the output will be logged + to the current task or run, whichever is the nearest ancestor. + attributes: Additional attributes to attach to the output. """ task = current_task_span.get() run = current_run_span.get() @@ -1402,7 +1404,7 @@ async def my_task(x: int) -> int: "log_output() must be called within a run or a task", ) - target.log_output(name, value, label=label, **attributes) + target.log_output(name, value, label=label, attributes=attributes) @handle_internal_errors() def log_outputs( @@ -1419,7 +1421,12 @@ def log_outputs( self.log_output(name, value, to=to) @handle_internal_errors() - def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> None: + def link_objects( + self, + origin: t.Any, + link: t.Any, + attributes: AnyDict | None = None, + ) -> None: """ Associate two runtime objects with each other. @@ -1439,14 +1446,14 @@ def link_objects(self, origin: t.Any, link: t.Any, **attributes: JsonValue) -> N Args: origin: The origin object to link from. link: The linked object to link to. - **attributes: Additional attributes to attach to the link. + attributes: Additional attributes to attach to the link. """ if (run := current_run_span.get()) is None: raise RuntimeError("link() must be called within a run") origin_hash = run.log_object(origin) link_hash = run.log_object(link) - run.link_objects(origin_hash, link_hash, **attributes) + run.link_objects(origin_hash, link_hash, attributes=attributes) DEFAULT_INSTANCE = Dreadnode() diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 8024a90d..ba43c9a9 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -87,9 +87,9 @@ class Span(ReadableSpan): def __init__( self, name: str, - attributes: AnyDict, tracer: Tracer, *, + attributes: AnyDict | None = None, label: str | None = None, type: SpanType = "span", tags: t.Sequence[str] | None = None, @@ -106,7 +106,7 @@ def __init__( SPAN_ATTRIBUTE_TYPE: type, SPAN_ATTRIBUTE_LABEL: self._label, SPAN_ATTRIBUTE_TAGS_: self.tags, - **attributes, + **(attributes or {}), } self._tracer = tracer @@ -198,9 +198,7 @@ def set_attribute( self._added_attributes = True if schema and raw is False: self._schema[key] = create_json_schema(value, set()) - otel_value = self._pre_attributes[key] = ( - value if raw else prepare_otlp_attribute(value) - ) + otel_value = self._pre_attributes[key] = value if raw else prepare_otlp_attribute(value) if self._span is not None: self._span.set_attribute(key, otel_value) self._pre_attributes[key] = otel_value @@ -260,11 +258,7 @@ def __init__( **({SPAN_ATTRIBUTE_INPUTS: inputs} if inputs else {}), **({SPAN_ATTRIBUTE_OUTPUTS: outputs} if outputs else {}), **({SPAN_ATTRIBUTE_OBJECTS: objects} if objects else {}), - **( - {SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} - if object_schemas - else {} - ), + **({SPAN_ATTRIBUTE_OBJECT_SCHEMAS: object_schemas} if object_schemas else {}), } # Mark objects and schemas as large attributes if present @@ -276,7 +270,7 @@ def __init__( large_attrs.append(SPAN_ATTRIBUTE_OBJECT_SCHEMAS) attributes[SPAN_ATTRIBUTE_LARGE_ATTRIBUTES] = large_attrs - super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update") + super().__init__(f"run.{run_id}.update", tracer, type="run_update", attributes=attributes) class RunSpan(Span): @@ -284,11 +278,11 @@ def __init__( self, name: str, project: str, - attributes: AnyDict, tracer: Tracer, file_system: AbstractFileSystem, prefix_path: str, *, + attributes: AnyDict | None = None, params: AnyDict | None = None, metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, @@ -333,9 +327,9 @@ def __init__( attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()), SPAN_ATTRIBUTE_PROJECT: project, - **attributes, + **(attributes or {}), } - super().__init__(name, attributes, tracer, type=type, tags=tags) + super().__init__(name, tracer, attributes=attributes, type=type, tags=tags) @classmethod def from_context( @@ -434,9 +428,7 @@ def push_update(self, *, force: bool = False) -> None: return current_time = time.time() - force_update = force or ( - current_time - self._last_update_time >= self._update_frequency - ) + force_update = force or (current_time - self._last_update_time >= self._update_frequency) should_update = force_update and ( self._pending_params or self._pending_inputs @@ -458,9 +450,7 @@ def push_update(self, *, force: bool = False) -> None: inputs=self._pending_inputs if self._pending_inputs else None, outputs=self._pending_outputs if self._pending_outputs else None, objects=self._pending_objects if self._pending_objects else None, - object_schemas=self._pending_object_schemas - if self._pending_object_schemas - else None, + object_schemas=self._pending_object_schemas if self._pending_object_schemas else None, ): pass @@ -483,7 +473,7 @@ def log_object( *, label: str | None = None, event_name: str = EVENT_NAME_OBJECT, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: serialized = serialize(value) data_hash = serialized.data_hash @@ -509,7 +499,7 @@ def log_object( # Build event attributes, use composite hash in events event_attributes = { - **attributes, + **(attributes or {}), EVENT_ATTRIBUTE_OBJECT_HASH: composite_hash, EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: trace_api.format_span_id( trace_api.get_current_span().get_span_context().span_id, @@ -541,9 +531,7 @@ def _store_file_by_hash(self, data: bytes, full_path: str) -> str: return str(self._file_system.unstrip_protocol(full_path)) - def _create_object_by_hash( - self, serialized: Serialized, object_hash: str - ) -> Object: + def _create_object_by_hash(self, serialized: Serialized, object_hash: str) -> Object: """Create an ObjectVal or ObjectUri depending on size with a specific hash.""" data = serialized.data data_bytes = serialized.data_bytes @@ -621,7 +609,7 @@ def log_input( label: str | None = None, **attributes: JsonValue, ) -> None: - label = label or clean_str(name) + label = clean_str(label or name) hash_ = self.log_object( value, label=label, @@ -736,7 +724,7 @@ def log_output( label: str | None = None, **attributes: JsonValue, ) -> None: - label = label or clean_str(name) + label = clean_str(label or name) hash_ = self.log_object( value, label=label, @@ -751,10 +739,10 @@ class TaskSpan(Span, t.Generic[R]): def __init__( self, name: str, - attributes: AnyDict, run_id: str, tracer: Tracer, *, + attributes: AnyDict | None = None, label: str | None = None, metrics: MetricsDict | None = None, tags: t.Sequence[str] | None = None, @@ -765,18 +753,16 @@ def __init__( self._output: R | Unset = UNSET # For the python output - self._context_token: Token[TaskSpan[t.Any] | None] | None = ( - None # contextvars context - ) + self._context_token: Token[TaskSpan[t.Any] | None] | None = None # contextvars context attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id), SPAN_ATTRIBUTE_INPUTS: self._inputs, SPAN_ATTRIBUTE_METRICS: self._metrics, SPAN_ATTRIBUTE_OUTPUTS: self._outputs, - **attributes, + **(attributes or {}), } - super().__init__(name, attributes, tracer, type="task", label=label, tags=tags) + super().__init__(name, tracer, type="task", attributes=attributes, label=label, tags=tags) def __enter__(self) -> te.Self: self._parent_task = current_task_span.get() @@ -832,7 +818,7 @@ def log_output( label: str | None = None, **attributes: JsonValue, ) -> str: - label = label or clean_str(name) + label = clean_str(label or name) if self._run is None: serialized = serialize(value) @@ -844,9 +830,7 @@ def log_output( label=label, event_name=EVENT_NAME_OBJECT_OUTPUT, ) - self._outputs.append( - ObjectRef(name, label=label, hash=hash_, attributes=attributes) - ) + self._outputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) return hash_ @property @@ -863,7 +847,7 @@ def log_input( label: str | None = None, **attributes: JsonValue, ) -> str: - label = label or clean_str(name) + label = clean_str(label or name) if self._run is None: serialized = serialize(value) @@ -875,9 +859,7 @@ def log_input( label=label, event_name=EVENT_NAME_OBJECT_INPUT, ) - self._inputs.append( - ObjectRef(name, label=label, hash=hash_, attributes=attributes) - ) + self._inputs.append(ObjectRef(name, label=label, hash=hash_, attributes=attributes)) return hash_ @property @@ -939,9 +921,7 @@ def log_metric( # this task-metric was logged here. if (run := current_run_span.get()) is not None: - metric = run.log_metric( - key, metric, prefix=self._label, origin=origin, mode=mode - ) + metric = run.log_metric(key, metric, prefix=self._label, origin=origin, mode=mode) self._metrics.setdefault(key, []).append(metric) diff --git a/dreadnode/util.py b/dreadnode/util.py index 9380c391..89262d23 100644 --- a/dreadnode/util.py +++ b/dreadnode/util.py @@ -32,7 +32,7 @@ def clean_str(s: str) -> str: """ Clean a string by replacing all non-alphanumeric characters (except `/` and `@`) with underscores. """ - return re.sub(r"[^\w/@]+", "_", s.lower()) + return re.sub(r"[^\w/@]+", "_", s.lower()).strip("_") def safe_repr(obj: t.Any) -> str: From 001f5b899ed0cbc6617fec8483438be840fcbcb8 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 16 Jul 2025 17:57:40 -0600 Subject: [PATCH 3/4] Fixing some bugs --- docs/sdk/task.mdx | 32 +++++++++++++++++++------------- dreadnode/object.py | 4 ++-- dreadnode/task.py | 26 ++++++++++++++++---------- dreadnode/tracing/span.py | 14 +++++++------- 4 files changed, 44 insertions(+), 32 deletions(-) diff --git a/docs/sdk/task.mdx b/docs/sdk/task.mdx index 812b930a..aaa4593f 100644 --- a/docs/sdk/task.mdx +++ b/docs/sdk/task.mdx @@ -290,10 +290,14 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: run = current_run_span.get() log_inputs = ( - run.autolog if run and isinstance(self.log_inputs, Inherited) else self.log_inputs + (run.autolog if run else False) + if isinstance(self.log_inputs, Inherited) + else self.log_inputs ) log_output = ( - run.autolog if run and isinstance(self.log_output, Inherited) else self.log_output + (run.autolog if run else False) + if isinstance(self.log_output, Inherited) + else self.log_output ) bound_args = self._bind_args(*args, **kwargs) @@ -319,7 +323,7 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if self.log_execution_metrics: + if run and self.log_execution_metrics: run.log_metric( "count", 1, @@ -329,7 +333,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) input_object_hashes: list[str] = [ - span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True) + span.log_input( + name, value, label=f"{self.label}.input.{name}", attributes={"auto": True} + ) for name, value in inputs_to_log.items() ] @@ -366,7 +372,7 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) ): output_object_hash = span.log_output( - "output", output, label=f"{self.label}.output", auto=True + "output", output, label=f"{self.label}.output", attributes={"auto": True} ) # Link the output to the inputs @@ -746,7 +752,7 @@ with_( log_output: bool | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: Any, + attributes: AnyDict | None = None, ) -> Task[P, R] ``` @@ -794,9 +800,9 @@ Clone a task and modify its attributes. `False` ) –If True, appends the new scorers and tags to the existing ones. If False, replaces them. -* **`**attributes`** - (`Any`, default: - `{}` +* **`attributes`** + (`AnyDict | None`, default: + `None` ) –Additional attributes to set or update in the task. @@ -818,7 +824,7 @@ def with_( log_output: bool | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> "Task[P, R]": """ Clone a task and modify its attributes. @@ -832,7 +838,7 @@ def with_( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics such as success rate and run count. append: If True, appends the new scorers and tags to the existing ones. If False, replaces them. - **attributes: Additional attributes to set or update in the task. + attributes: Additional attributes to set or update in the task. Returns: A new Task instance with the modified attributes. @@ -854,11 +860,11 @@ def with_( if append: task.scorers.extend(new_scorers) task.tags.extend(new_tags) - task.attributes.update(attributes) + task.attributes.update(attributes or {}) else: task.scorers = new_scorers task.tags = new_tags - task.attributes = attributes + task.attributes = attributes or {} return task ``` diff --git a/dreadnode/object.py b/dreadnode/object.py index 26e7aab9..28dbe589 100644 --- a/dreadnode/object.py +++ b/dreadnode/object.py @@ -1,7 +1,7 @@ import typing as t from dataclasses import dataclass -from dreadnode.types import JsonDict +from dreadnode.types import AnyDict @dataclass @@ -9,7 +9,7 @@ class ObjectRef: name: str label: str hash: str - attributes: JsonDict + attributes: AnyDict | None @dataclass diff --git a/dreadnode/task.py b/dreadnode/task.py index e9975d4e..b11dd8cc 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -10,7 +10,7 @@ from dreadnode.metric import Scorer, ScorerCallable from dreadnode.serialization import seems_useful_to_serialize from dreadnode.tracing.span import TaskSpan, current_run_span -from dreadnode.types import INHERITED, Inherited +from dreadnode.types import INHERITED, AnyDict, Inherited P = t.ParamSpec("P") R = t.TypeVar("R") @@ -184,7 +184,7 @@ def with_( log_output: bool | None = None, log_execution_metrics: bool | None = None, append: bool = False, - **attributes: t.Any, + attributes: AnyDict | None = None, ) -> "Task[P, R]": """ Clone a task and modify its attributes. @@ -198,7 +198,7 @@ def with_( log_output: Log the result of the function as an output. log_execution_metrics: Log execution metrics such as success rate and run count. append: If True, appends the new scorers and tags to the existing ones. If False, replaces them. - **attributes: Additional attributes to set or update in the task. + attributes: Additional attributes to set or update in the task. Returns: A new Task instance with the modified attributes. @@ -220,11 +220,11 @@ def with_( if append: task.scorers.extend(new_scorers) task.tags.extend(new_tags) - task.attributes.update(attributes) + task.attributes.update(attributes or {}) else: task.scorers = new_scorers task.tags = new_tags - task.attributes = attributes + task.attributes = attributes or {} return task @@ -243,10 +243,14 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: run = current_run_span.get() log_inputs = ( - run.autolog if run and isinstance(self.log_inputs, Inherited) else self.log_inputs + (run.autolog if run else False) + if isinstance(self.log_inputs, Inherited) + else self.log_inputs ) log_output = ( - run.autolog if run and isinstance(self.log_output, Inherited) else self.log_output + (run.autolog if run else False) + if isinstance(self.log_output, Inherited) + else self.log_output ) bound_args = self._bind_args(*args, **kwargs) @@ -272,7 +276,7 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: run_id=run.run_id if run else "", tracer=self.tracer, ) as span: - if self.log_execution_metrics: + if run and self.log_execution_metrics: run.log_metric( "count", 1, @@ -282,7 +286,9 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) input_object_hashes: list[str] = [ - span.log_input(name, value, label=f"{self.label}.input.{name}", auto=True) + span.log_input( + name, value, label=f"{self.label}.input.{name}", attributes={"auto": True} + ) for name, value in inputs_to_log.items() ] @@ -319,7 +325,7 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]: ) ): output_object_hash = span.log_output( - "output", output, label=f"{self.label}.output", auto=True + "output", output, label=f"{self.label}.output", attributes={"auto": True} ) # Link the output to the inputs diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index ba43c9a9..09c84cf3 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -34,7 +34,7 @@ from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize -from dreadnode.types import UNSET, AnyDict, JsonDict, JsonValue, Unset +from dreadnode.types import UNSET, AnyDict, JsonDict, Unset from dreadnode.util import clean_str from dreadnode.version import VERSION @@ -566,12 +566,12 @@ def link_objects( self, object_hash: str, link_hash: str, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: self.log_event( name=EVENT_NAME_OBJECT_LINK, attributes={ - **attributes, + **(attributes or {}), EVENT_ATTRIBUTE_OBJECT_HASH: object_hash, EVENT_ATTRIBUTE_LINK_HASH: link_hash, EVENT_ATTRIBUTE_ORIGIN_SPAN_ID: ( @@ -607,7 +607,7 @@ def log_input( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: label = clean_str(label or name) hash_ = self.log_object( @@ -722,7 +722,7 @@ def log_output( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> None: label = clean_str(label or name) hash_ = self.log_object( @@ -816,7 +816,7 @@ def log_output( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: label = clean_str(label or name) @@ -845,7 +845,7 @@ def log_input( value: t.Any, *, label: str | None = None, - **attributes: JsonValue, + attributes: AnyDict | None = None, ) -> str: label = clean_str(label or name) From de78a9c94a9ac50899b1608787102d3ad4b41d21 Mon Sep 17 00:00:00 2001 From: monoxgas Date: Wed, 16 Jul 2025 19:40:02 -0600 Subject: [PATCH 4/4] Add set_exception to base spans. Handle tasks wrapping other tasks. --- docs/sdk/main.mdx | 12 ++++++++++++ docs/sdk/task.mdx | 15 +++++++++------ dreadnode/main.py | 12 ++++++++++++ dreadnode/task.py | 4 ++-- dreadnode/tracing/span.py | 19 +++++++++++++++++++ 5 files changed, 54 insertions(+), 8 deletions(-) diff --git a/docs/sdk/main.mdx b/docs/sdk/main.mdx index 5579dee3..3ebd624b 100644 --- a/docs/sdk/main.mdx +++ b/docs/sdk/main.mdx @@ -2074,6 +2074,18 @@ def task( def make_task( func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R], ) -> Task[P, R]: + if isinstance(func, Task): + return func.with_( + name=name, + label=label, + log_inputs=log_inputs, + log_output=log_output, + log_execution_metrics=log_execution_metrics, + tags=tags, + attributes=attributes, + append=True, + ) + unwrapped = inspect.unwrap(func) if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction( diff --git a/docs/sdk/task.mdx b/docs/sdk/task.mdx index aaa4593f..4de2c1fb 100644 --- a/docs/sdk/task.mdx +++ b/docs/sdk/task.mdx @@ -748,8 +748,11 @@ with_( name: str | None = None, tags: Sequence[str] | None = None, label: str | None = None, - log_inputs: Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: Sequence[str] + | bool + | Inherited + | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, attributes: AnyDict | None = None, @@ -781,12 +784,12 @@ Clone a task and modify its attributes. ) –The new label for the task. * **`log_inputs`** - (`Sequence[str] | bool | None`, default: + (`Sequence[str] | bool | Inherited | None`, default: `None` ) –Log all, or specific, incoming arguments to the function as inputs. * **`log_output`** - (`bool | None`, default: + (`bool | Inherited | None`, default: `None` ) –Log the result of the function as an output. @@ -820,8 +823,8 @@ def with_( name: str | None = None, tags: t.Sequence[str] | None = None, label: str | None = None, - log_inputs: t.Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: t.Sequence[str] | bool | Inherited | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, attributes: AnyDict | None = None, diff --git a/dreadnode/main.py b/dreadnode/main.py index 14323001..96d0600a 100644 --- a/dreadnode/main.py +++ b/dreadnode/main.py @@ -586,6 +586,18 @@ async def my_task(x: int) -> int: def make_task( func: t.Callable[P, t.Awaitable[R]] | t.Callable[P, R], ) -> Task[P, R]: + if isinstance(func, Task): + return func.with_( + name=name, + label=label, + log_inputs=log_inputs, + log_output=log_output, + log_execution_metrics=log_execution_metrics, + tags=tags, + attributes=attributes, + append=True, + ) + unwrapped = inspect.unwrap(func) if inspect.isgeneratorfunction(unwrapped) or inspect.isasyncgenfunction( diff --git a/dreadnode/task.py b/dreadnode/task.py index b11dd8cc..37bd9b0e 100644 --- a/dreadnode/task.py +++ b/dreadnode/task.py @@ -180,8 +180,8 @@ def with_( name: str | None = None, tags: t.Sequence[str] | None = None, label: str | None = None, - log_inputs: t.Sequence[str] | bool | None = None, - log_output: bool | None = None, + log_inputs: t.Sequence[str] | bool | Inherited | None = None, + log_output: bool | Inherited | None = None, log_execution_metrics: bool | None = None, append: bool = False, attributes: AnyDict | None = None, diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 37620ae4..54898b16 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -264,6 +264,25 @@ def log_event( attributes=prepare_otlp_attributes(attributes or {}), ) + def set_exception( + self, + exception: BaseException, + *, + attributes: AnyDict | None = None, + status: Status | None = None, + ) -> None: + if self._span is None: + raise ValueError("Span is not active") + + if status is None: + status = Status(StatusCode.ERROR, str(exception)) + + self._span.set_status(status) + self._span.record_exception( + exception, + attributes=prepare_otlp_attributes(attributes or {}), + ) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(name='{self._span_name}', id={self.span_id},"