Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 0 additions & 20 deletions agent.py

This file was deleted.

42 changes: 3 additions & 39 deletions docs/sdk/task.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ Task(
name: str | None = None,
label: str | None = None,
scorers: ScorersLike[R] | None = None,
assert_scores: list[str] | Literal[True] | None = None,
log_inputs: Sequence[str]
| bool
| Inherited = INHERITED,
Expand All @@ -44,7 +43,6 @@ def __init__(
name: str | None = None,
label: str | None = None,
scorers: ScorersLike[R] | None = None,
assert_scores: list[str] | t.Literal[True] | None = None,
log_inputs: t.Sequence[str] | bool | Inherited = INHERITED,
log_output: bool | Inherited = INHERITED,
log_execution_metrics: bool = False,
Expand Down Expand Up @@ -86,9 +84,6 @@ def __init__(
"The label of the task - used to group associated metrics and data together."
self.scorers = Scorer.fit_like(scorers)
"A list of scorers to evaluate the task's output."
scorer_names = [s.name for s in self.scorers]
self.assert_scores = scorer_names if assert_scores is True else list(assert_scores or [])
"A list of score names to ensure have truthy values, otherwise raise an AssertionFailedError."
self.tags = list(tags or [])
"A list of tags to attach to the task span."
self.attributes = attributes
Expand All @@ -101,29 +96,11 @@ def __init__(
"Log the result of the function as an output."
self.log_execution_metrics = log_execution_metrics
"Track execution metrics such as success rate and run count."

for assertion in self.assert_scores or []:
if assertion not in scorer_names:
raise ValueError(
f"Unknown '{assertion}' in assert_scores, it must be one of {scorer_names}"
)
```


</Accordion>

### assert\_scores

```python
assert_scores = (
scorer_names
if assert_scores is True
else list(assert_scores or [])
)
```

A list of score names that must have truthy values; otherwise an AssertionFailedError is raised.

### attributes

```python
Expand Down Expand Up @@ -525,6 +502,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
Returns:
The span associated with task execution.
"""

from dreadnode import score

run = current_run_span.get()
Expand Down Expand Up @@ -623,7 +601,7 @@ async def run_always(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:

# Score and check assertions

await score(output, self.scorers, assert_scores=self.assert_scores)
await score(output, self.scorers) # assert_scores=self.assert_scores)

if run and self.log_execution_metrics:
run.log_metric(
Expand Down Expand Up @@ -942,9 +920,6 @@ with_(
*,
scorers: Sequence[Scorer[R] | ScorerCallable[R]]
| None = None,
assert_scores: Sequence[str]
| Literal[True]
| None = None,
name: str | None = None,
tags: Sequence[str] | None = None,
label: str | None = None,
Expand All @@ -968,11 +943,6 @@ Clone a task and modify its attributes.
`None`
)
–A list of new scorers to set or append to the task.
* **`assert_scores`**
(`Sequence[str] | Literal[True] | None`, default:
`None`
)
–A list of new assertion names to set or append to the task.
* **`name`**
(`str | None`, default:
`None`
Expand Down Expand Up @@ -1025,7 +995,6 @@ def with_(
self,
*,
scorers: t.Sequence[Scorer[R] | ScorerCallable[R]] | None = None,
assert_scores: t.Sequence[str] | t.Literal[True] | None = None,
name: str | None = None,
tags: t.Sequence[str] | None = None,
label: str | None = None,
Expand All @@ -1040,7 +1009,6 @@ def with_(

Args:
scorers: A list of new scorers to set or append to the task.
assert_scores: A list of new assertion names to set or append to the task.
name: The new name for the task.
tags: A list of new tags to set or append to the task.
label: The new label for the task.
Expand Down Expand Up @@ -1072,19 +1040,15 @@ def with_(

new_scorers = Scorer.fit_like(scorers or [])
new_tags = list(tags or [])
new_assert_scores = (
[s.name for s in new_scorers] if assert_scores is True else list(assert_scores or [])
)

if append:
task.scorers.extend(new_scorers)
task.tags.extend(new_tags)
task.assert_scores.extend(new_assert_scores)
task.attributes.update(attributes or {})
else:
task.scorers = new_scorers
task.tags = new_tags
task.assert_scores = new_assert_scores
# task.assert_scores = new_assert_scores
task.attributes = attributes or {}

return task
Expand Down
8 changes: 4 additions & 4 deletions dreadnode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from loguru import logger

from dreadnode import agent, convert, data_types, eval, meta, transforms # noqa: A004
from dreadnode import agent, convert, data_types, evals, meta, transforms
from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
from dreadnode.eval import Eval
from dreadnode.evals import Evaluation
from dreadnode.logging import configure_logging
from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
from dreadnode.meta import (
Expand Down Expand Up @@ -70,7 +70,7 @@
"CurrentTask",
"DatasetField",
"Dreadnode",
"Eval",
"Evaluation",
"Image",
"Markdown",
"Metric",
Expand Down Expand Up @@ -100,7 +100,7 @@
"continue_run",
"convert",
"data_types",
"eval",
"evals",
"get_run_context",
"link_objects",
"log_artifact",
Expand Down
37 changes: 1 addition & 36 deletions dreadnode/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
ToolStart,
_total_usage_from_events,
)
from dreadnode.agent.hooks import retry_with_feedback
from dreadnode.agent.reactions import (
Continue,
Fail,
Expand All @@ -48,7 +47,7 @@
RetryWithFeedback,
)
from dreadnode.agent.result import AgentResult
from dreadnode.agent.stop import StopCondition, stop_never
from dreadnode.agent.stop import StopCondition
from dreadnode.agent.thread import Thread
from dreadnode.agent.tools import AnyTool, Tool, Toolset, discover_tools_on_obj
from dreadnode.agent.types import Message, ToolCall
Expand Down Expand Up @@ -732,37 +731,3 @@ async def run(
raise RuntimeError("Agent run finished unexpectedly.") # noqa: TRY004

return final_event.result


class TaskAgent(Agent):
    """
    An Agent subclass geared toward driving a task to completion and reporting on it.

    After model initialization it:

    - Adds the `finish_task`, `give_up_on_task`, and `update_todo` tools, each only
      when no tool with that name is already registered.
    - Appends a `stop_never()` stop condition, so stalling behavior is triggered
      when no tool calls are made.
    - Inserts a `retry_with_feedback` hook for `AgentStalled` at the front of the
      hook list, nudging the model to continue or finish the task.
    """

    def model_post_init(self, _: t.Any) -> None:
        """Register the default task tools, stop condition, and stall hook."""
        # NOTE(review): this override does not call super().model_post_init() -
        # confirm the base Agent does not rely on its own post-init logic.
        from dreadnode.agent.tools.planning import update_todo
        from dreadnode.agent.tools.tasking import finish_task, give_up_on_task

        # Order matters: tools are appended in this sequence, and each presence
        # check sees the tools appended by the previous iterations.
        default_tools = (
            ("finish_task", finish_task),
            ("give_up_on_task", give_up_on_task),
            ("update_todo", update_todo),
        )
        for tool_name, default_tool in default_tools:
            already_registered = any(
                tool for tool in self.tools if tool.name == tool_name
            )
            if not already_registered:
                self.tools.append(default_tool)

        # Force the agent to use finish_task
        self.stop_conditions.append(stop_never())
        stall_hook = retry_with_feedback(
            event_type=AgentStalled,
            feedback="Continue the task if possible or use the 'finish_task' tool to complete it.",
        )
        self.hooks.insert(0, stall_hook)
2 changes: 1 addition & 1 deletion dreadnode/agent/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _handle_tool_start(self, event: ToolStart) -> None:
Text(f"Running [bold]{event.tool_call.name}[/bold]...", style="yellow")
)

def _handle_tool_end(self, event: ToolEnd):
def _handle_tool_end(self, event: ToolEnd) -> None:
"""Prints the tool's result and cleans up the status board."""
# First, print the static result panel. This ensures it's in the
# console history even after the live display is gone.
Expand Down
52 changes: 52 additions & 0 deletions dreadnode/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,58 @@

if t.TYPE_CHECKING:
from dreadnode.agent.agent import Agent
from dreadnode.agent.tools import Toolset


def format_tools_table(tools: "list[Toolset]") -> RenderableType:
    """
    Takes a list of Toolset objects and formats them into a concise rich Table.

    Each toolset becomes one row holding its name, the first line of its
    docstring, its variant, and a comma-separated list of its tool names.
    Any value that is missing or empty is rendered as "-".

    Args:
        tools: The toolsets to summarize.

    Returns:
        The populated table.
    """
    table = Table(box=box.ROUNDED)
    table.add_column("Name", style="orange_red1", no_wrap=True)
    table.add_column("Description", min_width=20)
    table.add_column("Variant", style="cyan", no_wrap=True)
    table.add_column("Methods", style="cyan")

    for toolset in tools:
        # Fall back to the placeholder when the toolset exposes no tools; the
        # previous guard tested the toolset object itself, so an empty tool
        # list produced a blank cell instead of "-".
        tool_names = ", ".join(tool.name for tool in toolset.get_tools()) or "-"

        # Only the first docstring line is shown; a missing or whitespace-only
        # docstring gets the same placeholder.
        doc = toolset.__doc__
        summary = (doc.strip().split("\n")[0] if doc else "") or "-"

        table.add_row(
            toolset.name,
            summary,
            toolset.variant or "-",
            tool_names,
        )

    return table


def format_tool(toolset: "Toolset") -> RenderableType:
    """
    Takes a single Toolset and formats its full details into a rich Panel.

    The panel is titled with the toolset name and wraps a header-less,
    two-column table with "Description" and "Variant" rows, plus a "Methods"
    row when the toolset reports any tools.

    Args:
        toolset: The toolset to describe.

    Returns:
        The panel wrapping the details table.
    """
    accent = "orange_red1"

    details = Table(box=box.MINIMAL, show_header=False, style=accent)
    details.add_column("Property", style="bold dim", justify="right", no_wrap=True)
    details.add_column("Value", style="white")

    # Description: the whole stripped docstring, or a placeholder if absent.
    description_label = Text("Description", justify="right")
    doc = toolset.__doc__
    description = doc.strip() if doc else "-"
    details.add_row(description_label, description)

    variant_label = Text("Variant", justify="right")
    details.add_row(variant_label, toolset.variant or "-")

    # The methods row is omitted entirely when there are no tools.
    if toolset.get_tools():
        markup_names = [f"[cyan]{tool.name}[/]" for tool in toolset.get_tools()]
        details.add_row(Text("Methods", justify="right"), ", ".join(markup_names))

    panel_title = f"[bold]{toolset.name}[/]"
    return Panel(
        details,
        title=panel_title,
        title_align="left",
        border_style=accent,
    )


def format_agents_table(agents: "list[Agent]") -> RenderableType:
Expand Down
8 changes: 0 additions & 8 deletions dreadnode/agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import importlib
import typing as t

from dreadnode.agent.tools import planning, reporting, tasking
from dreadnode.agent.tools.base import (
AnyTool,
Tool,
Expand All @@ -11,18 +10,11 @@
tool_method,
)

if t.TYPE_CHECKING:
from dreadnode.agent.tools import fs

__all__ = [
"AnyTool",
"Tool",
"Toolset",
"discover_tools_on_obj",
"fs",
"planning",
"reporting",
"tasking",
"tool",
"tool_method",
]
Expand Down
Loading
Loading