Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 21 additions & 40 deletions src/agentpool/agents/native_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from uuid import uuid4

import logfire
from pydantic_ai import Agent as PydanticAgent, CallToolsNode, ModelRequestNode, RunContext
from pydantic_ai import Agent as PydanticAgent, CallToolsNode, ModelRequestNode
from pydantic_ai.models import Model
from pydantic_ai.tools import ToolDefinition

from agentpool.agents.base_agent import BaseAgent
from agentpool.agents.context import AgentContext
from agentpool.agents.events import RunStartedEvent, StreamCompleteEvent
from agentpool.agents.events.processors import FileTracker
from agentpool.agents.exceptions import UnknownCategoryError, UnknownModeError
Expand All @@ -26,7 +26,6 @@
from agentpool.storage import StorageManager
from agentpool.tools import Tool, ToolManager
from agentpool.tools.exceptions import ToolError
from agentpool.utils.inspection import get_argument_key
from agentpool.utils.result_utils import to_type
from agentpool.utils.streams import merge_queue_into_iterator

Expand All @@ -47,7 +46,6 @@
from toprompt import AnyPromptType
from upathtools import JoinablePathLike

from agentpool.agents.context import AgentContext
from agentpool.agents.events import RichAgentStreamEvent
from agentpool.agents.modes import ModeCategory
from agentpool.common_types import (
Expand Down Expand Up @@ -606,7 +604,7 @@ async def get_agentlet[AgentOutputType](
model: ModelType | None,
output_type: type[AgentOutputType] | None,
input_provider: InputProvider | None = None,
) -> PydanticAgent[TDeps, AgentOutputType]:
) -> PydanticAgent[AgentContext[TDeps], Any]:
"""Create pydantic-ai agent from current state."""
from agentpool.agents.native_agent.tool_wrapping import wrap_tool

Expand All @@ -617,51 +615,30 @@ async def get_agentlet[AgentOutputType](
model_, _settings = self._resolve_model_string(actual_model)
else:
model_ = actual_model
agent = PydanticAgent(

context_for_tools = self.get_context(input_provider=input_provider)

# Collect pydantic_ai.tools.Tool instances using Tool.to_pydantic_ai()
pydantic_ai_tools = []
for tool in tools:
wrapped = wrap_tool(tool, context_for_tools, hooks=self._hook_manager)
pydantic_ai_tool = tool.to_pydantic_ai(function_override=wrapped)
pydantic_ai_tools.append(pydantic_ai_tool)

return PydanticAgent(
name=self.name,
model=model_,
model_settings=self.model_settings,
instructions=self._formatted_system_prompt,
retries=self._retries,
end_strategy=self._end_strategy,
output_retries=self._output_retries,
deps_type=self.deps_type or NoneType,
deps_type=AgentContext[TDeps],
output_type=final_type,
tools=pydantic_ai_tools,
builtin_tools=self._builtin_tools,
)

context_for_tools = self.get_context(input_provider=input_provider)

for tool in tools:
wrapped = wrap_tool(tool, context_for_tools, hooks=self._hook_manager)

prepare_fn = None
if tool.schema_override:

def create_prepare(
t: Tool,
) -> Callable[[RunContext[Any], ToolDefinition], Awaitable[ToolDefinition | None]]:
async def prepare_schema(
ctx: RunContext[Any], tool_def: ToolDefinition
) -> ToolDefinition | None:
if not t.schema_override:
return None
return ToolDefinition(
name=t.schema_override.get("name") or t.name,
description=t.schema_override.get("description") or t.description,
parameters_json_schema=t.schema_override.get("parameters"),
)

return prepare_schema

prepare_fn = create_prepare(tool)

if get_argument_key(wrapped, RunContext):
agent.tool(prepare=prepare_fn)(wrapped)
else:
agent.tool_plain(prepare=prepare_fn)(wrapped)
return agent # type: ignore[return-value]

async def _stream_events(
self,
prompts: list[UserContent],
Expand Down Expand Up @@ -692,9 +669,13 @@ async def _stream_events(
# Prepend pending context parts (prompts are already pydantic-ai UserContent format)
# Track tool call starts to combine with results later
file_tracker = FileTracker()
# Create AgentContext with user deps stored in .data
agent_deps = self.get_context(input_provider=input_provider)
if deps is not None:
agent_deps.data = deps
async with agentlet.iter(
prompts,
deps=deps, # type: ignore[arg-type]
deps=agent_deps,
message_history=[m for run in history_list for m in run.to_pydantic_ai()],
usage_limits=self._default_usage_limits,
) as agent_run:
Expand Down
2 changes: 1 addition & 1 deletion src/agentpool/agents/native_agent/tool_wrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def wrapped( # pyright: ignore[reportRedeclaration]
if result == "allow":
# Populate AgentContext with RunContext data if needed
if agent_ctx.data is None:
agent_ctx.data = ctx.deps
agent_ctx.data = ctx.deps.data if ctx.deps else ctx.deps

if agent_ctx_key: # inject AgentContext
# Build model_name from RunContext's model (provider:model_name format)
Expand Down
209 changes: 203 additions & 6 deletions src/agentpool/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
from collections.abc import Awaitable, Callable

from mcp.types import Tool as MCPTool, ToolAnnotations
from pydantic_ai import UserContent
from pydantic_ai import RunContext, UserContent
from pydantic_ai.tools import ToolDefinition
from schemez import FunctionSchema, Property

from agentpool.agents.context import AgentContext
from agentpool.common_types import ToolSource
from agentpool.tools.manager import ToolState

Expand Down Expand Up @@ -82,6 +84,15 @@ class Tool[TOutputType = Any]:
schema_override: schemez.OpenAIFunctionDefinition | None = None
"""Schema override. If not set, the schema is inferred from the callable."""

prepare: (
Callable[[RunContext[AgentContext], ToolDefinition], Awaitable[ToolDefinition | None]]
| None
) = None
"""Prepare function for tool schema customization."""

function_schema: Any | None = None
"""Function schema override for pydantic-ai tools."""

hints: ToolHints = field(default_factory=ToolHints)
"""Hints for the tool."""

Expand Down Expand Up @@ -113,18 +124,190 @@ class Tool[TOutputType = Any]:

@abstractmethod
def get_callable(self) -> Callable[..., TOutputType | Awaitable[TOutputType]]:
"""Get the callable for this tool. Subclasses must implement."""
"""Get callable for this tool. Subclasses must implement."""
...

def to_pydantic_ai(self) -> PydanticAiTool:
"""Convert tool to Pydantic AI tool."""
metadata = {**self.metadata, "agent_name": self.agent_name, "category": self.category}
def _get_effective_prepare(
self,
) -> (
Callable[[RunContext[AgentContext], ToolDefinition], Awaitable[ToolDefinition | None]]
| None
):
"""Get the effective prepare function for this tool.

Returns self.prepare if set.

Returns:
Prepare function or None.
"""
return self.prepare

def _detect_takes_ctx(self, func: Callable[..., Any] | None = None) -> bool:
"""Detect if function takes RunContext parameter.

Args:
func: The callable to inspect. If None, uses self.get_callable().

Returns:
True if function has a RunContext parameter, False otherwise.
"""
if func is None:
func = self.get_callable()

# Check for RunContext in function signature
sig = inspect.signature(func)
for param in sig.parameters.values():
# Check by string type name (works across TYPE_CHECKING)
if param.annotation == "RunContext" or (
hasattr(param.annotation, "__name__") and param.annotation.__name__ == "RunContext"
):
return True
return False

def _get_json_schema(self, func: Callable[..., Any] | None = None) -> dict[str, Any] | None:
"""Get effective JSON schema for this tool.

Returns a JSON schema dict if a custom schema is needed
(from schema_override or fallback to schemez), or None if
pydantic-ai should infer the schema automatically.

Args:
func: The callable to use for schema generation. If None, uses self.get_callable().

Returns:
JSON schema dict or None.
"""
if func is None:
func = self.get_callable()

# If no schema_override, let pydantic-ai infer the schema
if self.schema_override is None:
return None

# Try primary path with pydantic_ai.function_schema
try:
from pydantic_ai._function_schema import ( # type: ignore[attr-defined]
GenerateJsonSchema,
function_schema,
)

schema = function_schema(func, schema_generator=GenerateJsonSchema)

# Apply schema_override to generated schema
# Merge top-level description
if "description" in self.schema_override:
schema.json_schema["description"] = self.schema_override["description"]

if "parameters" in self.schema_override:
override_params = self.schema_override["parameters"]
# Merge custom parameter definitions (which include descriptions)
if "properties" in override_params:
for param_name, param_def in override_params["properties"].items():
if param_name in schema.json_schema.get("properties", {}):
# Update existing parameter with custom description
schema.json_schema["properties"][param_name].update(param_def)
else:
# Add new parameter
schema.json_schema.setdefault("properties", {})[param_name] = param_def
except Exception as e:
# Fallback to schemez if pydantic_ai.function_schema fails
from pydantic.errors import PydanticUndefinedAnnotation

if isinstance(e, (PydanticUndefinedAnnotation, NameError)):
logger.warning(
"pydantic_ai.function_schema failed for %s, falling back to schemez: %s",
self.name,
str(e),
)
else:
raise

# Fallback: use schemez to generate schema
from pydantic_ai import RunContext

from agentpool.agents.context import AgentContext

# Use schema_override description if provided, otherwise use self.description
desc = (
self.schema_override.get("description", self.description)
if self.schema_override
else self.description
)

# Use schemez to generate JSON schema
schema = schemez.create_schema( # type: ignore
func,
name_override=self.name,
description_override=desc,
exclude_types=[AgentContext, RunContext],
)

# Return only the parameters part (the "object" schema)
# Use model_dump - schemez.FunctionSchema has this method (pydantic-compatible)
schema_dump = getattr(schema, "model_dump")() # noqa: B009, type: ignore[attr-defined]
generated_params = schema_dump["parameters"]

# Apply parameter overrides to maintain consistency with the primary path
if "parameters" in self.schema_override:
override_params = self.schema_override["parameters"]
if "properties" in override_params:
for param_name, param_def in override_params["properties"].items():
if param_name in generated_params.get("properties", {}):
generated_params["properties"][param_name].update(param_def)
else:
generated_params.setdefault("properties", {})[param_name] = param_def
return generated_params # type: ignore[no-any-return]
else:
return schema.json_schema

def to_pydantic_ai(
self, function_override: Callable[..., TOutputType | Awaitable[TOutputType]] | None = None
) -> PydanticAiTool:
"""Convert tool to Pydantic AI tool.

Args:
function_override: Optional callable to override self.get_callable().

Returns:
PydanticAiTool instance configured for this tool.
"""
base_metadata = self.metadata or {}
metadata = {
**base_metadata,
"agent_name": self.agent_name,
"category": self.category,
}
function = function_override if function_override is not None else self.get_callable()

# Check if we have a custom JSON schema that needs to be used
json_schema = self._get_json_schema(function)

# If we have a custom schema, use Tool.from_schema
if json_schema is not None:
# Detect if function takes RunContext parameter
takes_ctx = self._detect_takes_ctx(function)

# Import Tool.from_schema at runtime to avoid circular imports
from pydantic_ai.tools import Tool as PydanticAiToolClass

tool_instance = PydanticAiToolClass.from_schema(
function=function,
name=self.name,
description=self.description,
json_schema=json_schema,
takes_ctx=takes_ctx,
)
# Tool.from_schema doesn't accept prepare parameter, assign it manually
tool_instance.prepare = self._get_effective_prepare() # type: ignore[assignment]
return tool_instance
# No custom schema, let pydantic-ai infer it automatically
return PydanticAiTool(
function=self.get_callable(),
function=function,
name=self.name,
description=self.description,
requires_approval=self.requires_confirmation,
metadata=metadata,
prepare=self._get_effective_prepare(), # type: ignore[arg-type]
)

@property
Expand Down Expand Up @@ -235,6 +418,11 @@ def from_callable(
name_override: str | None = None,
description_override: str | None = None,
schema_override: schemez.OpenAIFunctionDefinition | None = None,
prepare: (
Callable[[RunContext[AgentContext], ToolDefinition], Awaitable[ToolDefinition | None]]
| None
) = None,
function_schema: Any | None = None,
hints: ToolHints | None = None,
category: ToolKind | None = None,
enabled: bool = True,
Expand All @@ -247,6 +435,8 @@ def from_callable(
name_override=name_override,
description_override=description_override,
schema_override=schema_override,
prepare=prepare,
function_schema=function_schema,
hints=hints,
category=category,
enabled=enabled,
Expand Down Expand Up @@ -298,6 +488,11 @@ def from_callable(
name_override: str | None = None,
description_override: str | None = None,
schema_override: schemez.OpenAIFunctionDefinition | None = None,
prepare: (
Callable[[RunContext[AgentContext], ToolDefinition], Awaitable[ToolDefinition | None]]
| None
) = None,
function_schema: Any | None = None,
hints: ToolHints | None = None,
category: ToolKind | None = None,
enabled: bool = True,
Expand Down Expand Up @@ -327,6 +522,8 @@ def from_callable(
callable=callable_obj, # pyright: ignore[reportArgumentType]
import_path=import_path,
schema_override=schema_override,
prepare=prepare,
function_schema=function_schema,
category=category,
hints=hints or ToolHints(),
enabled=enabled,
Expand Down
Loading
Loading