diff --git a/docs/usage/runs.mdx b/docs/usage/runs.mdx index 92023afa..266ee8d8 100644 --- a/docs/usage/runs.mdx +++ b/docs/usage/runs.mdx @@ -263,6 +263,41 @@ with dn.run("risky-experiment"): dn.log_metric("error", 1.0) ``` +## Task Hierarchy and Relationships + +Every run maintains a list of its top-level tasks (tasks that are called directly within the run context, not as children of other tasks): + +```python +with dn.run("data-processing") as run: + # These tasks are added to the run's task list + cleaning_result = await clean_data.run(raw_data) + analysis_result = await analyze_data.run(cleaned_data) + +# Access all top-level tasks in the run +print(f"Run has {len(run.tasks)} top-level tasks") + +for task in run.tasks: + print(f"Task: {task}") + print(f"Run: {task.run}") + print(f"Child tasks: {len(task.tasks)}") +``` + +You can recursively gather all tasks within a run using `.all_tasks` and perform analysis on them: + +```python +with dn.run("comprehensive-analysis") as run: + await complex_workflow(data) + + all_tasks = run.all_tasks + print(f"Run '{run.run_id}' executed {len(all_tasks)} total tasks") + + # Analyze execution times, success rates, etc. + successful_tasks = [t for t in all_tasks if not t.failed] + success_rate = len(successful_tasks) / len(all_tasks) if all_tasks else 0.0 + print(f"Success rate: {success_rate * 100:.1f}%") + dn.log_metric("success_rate", success_rate) +``` + ## Best Practices 1. **Use meaningful names**: Give your runs descriptive names that indicate their purpose. @@ -270,3 +305,4 @@ with dn.run("risky-experiment"): 3. **Create separate runs for separate experiments**: Don't try to jam multiple experiments into a single run—you can create multiple runs inside your code. 4. **Use projects for organization**: Group related runs into projects. 5. **Create comparison runs**: When testing different approaches, ensure parameters and metrics are consistent to enable meaningful comparison. +6. **Leverage task hierarchy**: Organize complex workflows using hierarchical tasks within runs to maintain clear execution structure. diff --git a/docs/usage/tasks.mdx b/docs/usage/tasks.mdx index fc38c2ca..8fe2e77c 100644 --- a/docs/usage/tasks.mdx +++ b/docs/usage/tasks.mdx @@ -329,6 +329,47 @@ When this task logs a metric named `token_count`, that metric is: 1. Stored with the task span as `token_count` 2. Mirrored at the run level with the prefix `tokenize.token_count` +## Task Hierarchy and Relationships + +Every task maintains relationships with its parent task (if any) and subtasks (if any). These relationships are automatically established when tasks are called within other tasks: + +```python +import dreadnode as dn + +@dn.task() +async def process(data: str) -> str: + return f"processed: {data}" + +@dn.task() +async def finalize(data: str) -> str: + return f"finalized: {data}" + +@dn.task() +async def parent_task(data: str) -> str: + processed = await process(data) + finalized = await finalize(processed) + return finalized + +with dn.run("workflow-example"): + parent = await parent_task.run("input_data") + + print(len(parent.tasks)) # 2 + + # Iterate through child tasks + for task in parent.tasks: + print(f"{task!r}") + +# TaskSpan(name='process', label='process', run_id='...', parent_task='...', ...) +# TaskSpan(name='finalize', label='finalize', run_id='...', parent_task='...', ...) +``` + +The available hierarchy properties include: + +- **`task_span.tasks`**: List of child `TaskSpan` objects +- **`task_span.all_tasks`**: Flat list of all tasks under this task, including subtasks +- **`task_span.parent_task`**: Reference to parent `TaskSpan` (or `None` for top-level tasks) +- **`task_span.run`**: Reference to the `RunSpan` containing this task + ## Best Practices 1. **Keep tasks focused**: Each task should do one thing well, making it easier to trace and debug. @@ -336,4 +377,5 @@ When this task logs a metric named `token_count`, that metric is: 3. **Log relevant data**: Be intentional about what you log as inputs, outputs, and metrics. 4. **Handle errors appropriately**: Use `try_run()` and similar methods to handle task failures gracefully. 5. **Use tasks to structure your code**: Tasks help create natural boundaries in your application. -6. **Combine with [Rigging tools](/open-source/rigging/topics/tools)**: Tasks work seamlessly with Rigging tools for LLM agents. +6. **Leverage task hierarchy**: Use parent-child relationships to organize complex workflows and enable detailed analysis. +7. **Combine with [Rigging tools](/open-source/rigging/topics/tools)**: Tasks work seamlessly with Rigging tools for LLM agents. diff --git a/dreadnode/__init__.py b/dreadnode/__init__.py index 8ea2178d..6dc3320e 100644 --- a/dreadnode/__init__.py +++ b/dreadnode/__init__.py @@ -1,3 +1,4 @@ +from dreadnode import convert, data_types from dreadnode.data_types import Audio, Image, Object3D, Table, Video from dreadnode.main import DEFAULT_INSTANCE, Dreadnode from dreadnode.metric import Metric, MetricDict, Scorer @@ -33,6 +34,7 @@ __version__ = VERSION __all__ = [ + "DEFAULT_INSTANCE", "Audio", "Dreadnode", "Image", @@ -51,6 +53,10 @@ "__version__", "api", "configure", + "continue_run", + "convert", + "data_types", + "get_run_context", "link_objects", "log_artifact", "log_input", diff --git a/dreadnode/convert.py b/dreadnode/convert.py new file mode 100644 index 00000000..2d7f9e48 --- /dev/null +++ b/dreadnode/convert.py @@ -0,0 +1,42 @@ +import typing as t + +if t.TYPE_CHECKING: + import networkx as nx # type: ignore [import-untyped] + + from dreadnode.tracing.span import RunSpan + + +def run_span_to_graph(run: "RunSpan") -> "nx.DiGraph": + try: + import networkx as nx + except ImportError as e: + raise RuntimeError("The `networkx` package is required for graph conversion") from e + + graph = nx.DiGraph() + + graph.add_node( + run.run_id, + name=run.name, + label=run.label, + start_time=run.start_time, + end_time=run.end_time, + duration=run.duration, + status="failed" if run.failed else "running" if run.is_recording else "completed", + tags=run.tags, + ) + + for task in run.all_tasks: + graph.add_node( + task.span_id, + name=task.name, + label=task.label, + start_time=task.start_time, + end_time=task.end_time, + duration=task.duration, + status="failed" if task.failed else "running" if task.active else "completed", + tags=task.tags, + ) + + graph.add_edge(task.parent_task.span_id if task.parent_task else run.run_id, task.span_id) + + return graph diff --git a/dreadnode/tracing/span.py b/dreadnode/tracing/span.py index 665d219c..5b9b04b2 100644 --- a/dreadnode/tracing/span.py +++ b/dreadnode/tracing/span.py @@ -24,6 +24,7 @@ from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.trace import Tracer from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util import types as otel_types from ulid import ULID @@ -31,6 +32,7 @@ from dreadnode.artifact.storage import ArtifactStorage from dreadnode.artifact.tree_builder import ArtifactTreeBuilder, DirectoryNode from dreadnode.constants import MAX_INLINE_OBJECT_BYTES +from dreadnode.convert import run_span_to_graph from dreadnode.metric import Metric, MetricAggMode, MetricsDict from dreadnode.object import Object, ObjectRef, ObjectUri, ObjectVal from dreadnode.serialization import Serialized, serialize @@ -68,6 +70,9 @@ SpanType, ) +if t.TYPE_CHECKING: + import networkx as nx # type: ignore [import-untyped] + logger = logging.getLogger(__name__) R = t.TypeVar("R") @@ -83,6 +88,15 @@ ) +def _format_status(status: Status) -> str: + """Format the status for display.""" + if status.status_code == StatusCode.ERROR: + if status.description is None: + return "'error'" + return f"'error - {status.description}'" + return "'ok'" + + class Span(ReadableSpan): def __init__( self, @@ -172,12 +186,36 @@ def trace_id(self) -> str: raise ValueError("Span is not active") return trace_api.format_trace_id(self._span.get_span_context().trace_id) + @property + def label(self) -> str: + """Get the label of the span.""" + return self._label + @property def is_recording(self) -> bool: + """Check if the span is currently recording.""" if self._span is None: return False return self._span.is_recording() + @property + def active(self) -> bool: + """Check if the span is currently active (recording).""" + return self._span is not None and self._span.is_recording() + + @property + def failed(self) -> bool: + """Check if the span has failed.""" + return self.status.status_code == StatusCode.ERROR + + @property + def duration(self) -> float: + """Get the duration of the span in seconds.""" + if self._span is None: + return 0.0 + end_time = self.end_time or time.time() + return (end_time - self.start_time) if self.start_time else 0.0 + def set_tags(self, tags: t.Sequence[str]) -> None: tags = [tags] if isinstance(tags, str) else list(tags) tags = [clean_str(t) for t in tags] @@ -226,6 +264,15 @@ def log_event( attributes=prepare_otlp_attributes(attributes or {}), ) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(name='{self._span_name}', id={self.span_id}," + f"label='{self._label}', status={_format_status(self.status)}, active={self.is_recording})" + ) + + def __str__(self) -> str: + return f"{self._span_name} ({self._label})" if self._label else self._span_name + class RunContext(te.TypedDict): """Context for transferring and continuing runs in other places.""" @@ -272,6 +319,16 @@ def __init__( super().__init__(f"run.{run_id}.update", attributes, tracer, type="run_update") + def __repr__(self) -> str: + status = "active" if self.is_recording else "inactive" + run_id = self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "unknown") + project = self.get_attribute(SPAN_ATTRIBUTE_PROJECT, "unknown") + return f"RunUpdateSpan(run_id='{run_id}', project='{project}', status={status})" + + def __str__(self) -> str: + run_id = self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "unknown") + return f"run.{run_id}.update" + class RunSpan(Span): def __init__( @@ -324,6 +381,8 @@ def __init__( self._file_system = file_system self._prefix_path = prefix_path + self._tasks: list[TaskSpan[t.Any]] = [] + attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id or ULID()), SPAN_ATTRIBUTE_PROJECT: project, @@ -467,6 +526,19 @@ def push_update(self, *, force: bool = False) -> None: def run_id(self) -> str: return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "")) + @property + def tasks(self) -> "list[TaskSpan[t.Any]]": + return self._tasks + + @property + def all_tasks(self) -> "list[TaskSpan[t.Any]]": + """Get all tasks, including subtasks.""" + all_tasks = [] + for task in self._tasks: + all_tasks.append(task) + all_tasks.extend(task.all_tasks) + return all_tasks + def log_object( self, value: t.Any, @@ -731,6 +803,25 @@ def log_output( self._outputs.append(object_ref) self._pending_outputs.append(object_ref) + def to_graph(self) -> "nx.DiGraph": + return run_span_to_graph(self) + + def __repr__(self) -> str: + run_id = self.run_id + project = self.project + num_tasks = len(self._tasks) + num_objects = len(self._objects) + return ( + f"RunSpan(name='{self.name}', id='{run_id}', " + f"project='{project}', status={_format_status(self.status)}, active={self.is_recording}, " + f"tasks={num_tasks}, objects={num_objects})" + ) + + def __str__(self) -> str: + if self._label: + return f"{self.name} ({self._label}) - {self.run_id}" + return f"{self.name} - {self.run_id}" + class TaskSpan(Span, t.Generic[R]): def __init__( @@ -752,6 +843,9 @@ def __init__( self._context_token: Token[TaskSpan[t.Any] | None] | None = None # contextvars context + self._tasks: list[TaskSpan[t.Any]] = [] + self._parent_task: TaskSpan[t.Any] | None = None + attributes = { SPAN_ATTRIBUTE_RUN_ID: str(run_id), SPAN_ATTRIBUTE_INPUTS: self._inputs, @@ -762,14 +856,17 @@ def __init__( super().__init__(name, attributes, tracer, type="task", label=label, tags=tags) def __enter__(self) -> te.Self: - self._parent_task = current_task_span.get() - if self._parent_task is not None: - self.set_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, self._parent_task.span_id) - self._run = current_run_span.get() if self._run is None: raise RuntimeError("You cannot start a task span without a run") + self._parent_task = current_task_span.get() + if self._parent_task is not None: + self.set_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, self._parent_task.span_id) + self._parent_task._tasks.append(self) # noqa: SLF001 + else: + self._run._tasks.append(self) # noqa: SLF001 + self._context_token = current_task_span.set(self) return super().__enter__() @@ -788,14 +885,36 @@ def __exit__( @property def run_id(self) -> str: + """Get the run id this task is associated with (may be empty).""" return str(self.get_attribute(SPAN_ATTRIBUTE_RUN_ID, "")) @property def parent_task_id(self) -> str: + """Get the parent task ID if it exists (may be empty).""" return str(self.get_attribute(SPAN_ATTRIBUTE_PARENT_TASK_ID, "")) + @property + def parent_task(self) -> "TaskSpan[t.Any] | None": + """Get the parent task if it exists.""" + return self._parent_task + + @property + def tasks(self) -> list["TaskSpan[t.Any]"]: + """Get the list of children tasks.""" + return self._tasks + + @property + def all_tasks(self) -> list["TaskSpan[t.Any]"]: + """Get all tasks, including subtasks.""" + all_tasks = [] + for task in self._tasks: + all_tasks.append(task) + all_tasks.extend(task.all_tasks) + return all_tasks + @property def run(self) -> RunSpan: + """Get the run this task is associated with.""" if self._run is None: raise ValueError("Task span is not in an active run") return self._run @@ -924,6 +1043,25 @@ def get_average_metric_value(self, key: str | None = None) -> float: metrics, ) + def __repr__(self) -> str: + run_id = self.run_id + parent_task_id = self.parent_task_id + num_subtasks = len(self._tasks) + num_inputs = len(self._inputs) + num_outputs = len(self._outputs) + + parent_info = f", parent_task='{parent_task_id}'" if parent_task_id else "" + return ( + f"TaskSpan(name='{self.name}', label='{self._label}', " + f"run='{run_id}'{parent_info}, status={_format_status(self.status)}, active={self.is_recording}, " + f"tasks={num_subtasks}, inputs={num_inputs}, outputs={num_outputs})" + ) + + def __str__(self) -> str: + if self._label and self._label != self.name: + return f"{self.name} ({self._label})" + return self.name + def prepare_otlp_attributes( attributes: AnyDict,