diff --git a/docs/advanced/system-prompts.md b/docs/advanced/system-prompts.md index 5989302b1..cc21f0c84 100644 --- a/docs/advanced/system-prompts.md +++ b/docs/advanced/system-prompts.md @@ -135,6 +135,121 @@ def generate_context_prompt(user_id: str, session_type: str) -> str: return f"You are assisting {user_data.name} with {session_type}." ``` +### Provider-Based Dynamic Instructions + +ResourceProviders can provide dynamic instructions that are re-evaluated on each agent run with access to runtime context. This is different from function-generated prompts because instructions have access to AgentContext and RunContext. + +```yaml +agents: + context_aware_agent: + system_prompt: + - "You are a helpful assistant." + + toolsets: + - type: custom + import_path: myapp.providers.UserContextProvider + name: user_provider + + # Add provider-based dynamic instructions + instructions: + - type: provider + ref: user_provider +``` + +Provider implementation: + +```python +from agentpool.resource_providers import ResourceProvider +from agentpool.prompts.instructions import InstructionFunc +from agentpool.agents.context import AgentContext + +class UserContextProvider(ResourceProvider): + async def get_instructions(self) -> list[InstructionFunc]: + """Return dynamic instruction functions. + + Each function is re-evaluated on each run with access + to runtime context (AgentContext, RunContext, or both). + """ + return [ + self._get_user_context, # With AgentContext + self._get_system_status, # No context + ] + + async def _get_user_context(self, ctx: AgentContext) -> str: + """Generate context based on agent state.""" + # Access agent name, model, conversation history, etc. + return f"Agent: {ctx.name}, Model: {ctx.model_name}" + + def _get_system_status(self) -> str: + """Return static instruction.""" + return "System: Online" +``` + +#### Instruction Function Context Types + +Instruction functions can receive different context types: + +- **No context**: `() -> str` +- **AgentContext only**: `(AgentContext) -> str` +- **RunContext only**: `(RunContext) -> str` +- **Both contexts**: `(AgentContext, RunContext) -> str` + +```python +# No context +def simple() -> str: + return "Be helpful." + +# AgentContext only +async def with_agent(ctx: AgentContext) -> str: + return f"Agent: {ctx.name}" + +# RunContext only +async def with_run(ctx: RunContext) -> str: + return f"Model: {ctx.model_name}" + +# Both contexts +async def with_both(agent_ctx: AgentContext, run_ctx: RunContext) -> str: + return f"Agent {agent_ctx.name} using {run_ctx.model.model_name}" +``` + +#### Function-Generated vs Provider-Based Instructions + +| Feature | Function-Generated | Provider-Based | +|---------|-------------------|-----------------| +| **Location** | In `system_prompt` field | In `instructions` field | +| **Context Access** | No runtime context | AgentContext, RunContext, or both | +| **Re-evaluation** | Evaluated once at agent start | Re-evaluated on each run | +| **Best For** | Simple dynamic content | Context-aware instructions | + +#### Order of Instructions + +Instructions are processed in this order: + +1. **Static system prompts** (from `system_prompt` field) +2. **Provider instructions** (in order defined in `instructions` list) + +```yaml +# Resulting instruction order: +instructions: + - "You are an expert." # 1 + - type: provider + ref: provider_a # 2 + - "Follow these guidelines:" # 3 + - type: provider + ref: provider_b # 4 +``` + +#### Error Handling + +If an instruction function fails, the error is logged and the instruction is skipped. Agent initialization continues without crashing. + +!!! tip "Use Provider-Based Instructions When" + You need access to runtime state (AgentContext, RunContext) or want instructions that change on each run based on context like conversation history, available tools, or session state. + +!!! note "See Also" + - [Dynamic Instructions Example](../../examples/dynamic-instructions/) + - [Resource Providers](../configuration/resources.md) + ## Callable Prompts System prompts can include callables that are evaluated when the agent context starts: diff --git a/docs/configuration/resources.md b/docs/configuration/resources.md index 2884588f6..ba7341fb6 100644 --- a/docs/configuration/resources.md +++ b/docs/configuration/resources.md @@ -54,3 +54,100 @@ Resources are loaded on-demand when agents request them, supporting parameteriza - Source resources automatically extract docstrings and type hints - LangChain resources leverage the extensive LangChain loader ecosystem - Callable resources provide maximum flexibility for custom logic + +## Dynamic Instructions from Resource Providers + +ResourceProviders can now provide dynamic instructions that are re-evaluated on each agent run. This allows providers to generate context-aware instructions based on runtime state. + +### How It Works + +ResourceProviders can implement the `get_instructions()` method to return instruction functions: + +```python +from agentpool.resource_providers import ResourceProvider +from agentpool.prompts.instructions import InstructionFunc +from agentpool.agents.context import AgentContext + +class MyProvider(ResourceProvider): + async def get_instructions(self) -> list[InstructionFunc]: + """Return dynamic instruction functions.""" + return [ + self._get_static_instruction, # No context + self._get_context_instruction, # With AgentContext + ] + + def _get_static_instruction(self) -> str: + """Instruction without context access.""" + return "Always be helpful." + + async def _get_context_instruction(self, ctx: AgentContext) -> str: + """Instruction with context access.""" + return f"Agent: {ctx.name}, Model: {ctx.model_name}" +``` + +### YAML Configuration + +Configure providers to provide instructions using the `instructions` field: + +```yaml +agents: + my_agent: + type: native + model: openai:gpt-4o + toolsets: + - type: custom + import_path: myapp.providers.MyProvider + name: my_provider + + # Add provider-based instructions + instructions: + - type: provider + ref: my_provider +``` + +### Instruction Function Types + +Instruction functions can accept different context types: + +- **No context**: `() -> str` +- **AgentContext only**: `(AgentContext) -> str` +- **RunContext only**: `(RunContext) -> str` +- **Both contexts**: `(AgentContext, RunContext) -> str` + +```python +# No context +def simple() -> str: + return "Be helpful." + +# AgentContext only +async def with_agent(ctx: AgentContext) -> str: + return f"Agent: {ctx.name}" + +# RunContext only +async def with_run(ctx: RunContext) -> str: + return f"Model: {ctx.model.model_name}" + +# Both contexts +async def with_both(agent_ctx: AgentContext, run_ctx: RunContext) -> str: + return f"Agent {agent_ctx.name} using {run_ctx.model.model_name}" +``` + +### Benefits + +- **Context-aware**: Instructions adapt to runtime state (conversation history, tools used, etc.) +- **Per-run re-evaluation**: Unlike static prompts, dynamic instructions regenerate on each run +- **Provider integration**: Toolsets and other providers can inject their own contextual instructions +- **Flexible context access**: Choose what context you need (AgentContext, RunContext, or both) + +### Error Handling + +If an instruction function fails: +- Error is logged with context +- Agent initialization continues +- Failed instruction is skipped (uses empty string fallback) + +### See Also + +- [Dynamic Instructions Example](../../examples/dynamic-instructions/) +- [ResourceProvider Base Class](../api/resource_providers.md) +- [Instruction Types](../api/instructions.md) diff --git a/docs/examples/mcp_sampling_elicitation/demo.py b/docs/examples/mcp_sampling_elicitation/demo.py index d0840d1aa..36e4f9c7b 100644 --- a/docs/examples/mcp_sampling_elicitation/demo.py +++ b/docs/examples/mcp_sampling_elicitation/demo.py @@ -6,9 +6,10 @@ from __future__ import annotations -import anyio from pathlib import Path +import anyio + from agentpool import Agent from agentpool_config.mcp_server import StdioMCPServerConfig diff --git a/src/agentpool/agents/native_agent/agent.py b/src/agentpool/agents/native_agent/agent.py index b6ef25f83..b1fdd9305 100644 --- a/src/agentpool/agents/native_agent/agent.py +++ b/src/agentpool/agents/native_agent/agent.py @@ -12,9 +12,8 @@ from uuid import uuid4 import logfire -from pydantic_ai import Agent as PydanticAgent, CallToolsNode, ModelRequestNode, RunContext +from pydantic_ai import Agent as PydanticAgent, CallToolsNode, ModelRequestNode from pydantic_ai.models import Model -from pydantic_ai.tools import ToolDefinition from agentpool.agents.base_agent import BaseAgent from agentpool.agents.events import RunStartedEvent, StreamCompleteEvent @@ -26,7 +25,6 @@ from agentpool.storage import StorageManager from agentpool.tools import Tool, ToolManager from agentpool.tools.exceptions import ToolError -from agentpool.utils.inspection import get_argument_key from agentpool.utils.result_utils import to_type from agentpool.utils.streams import merge_queue_into_iterator @@ -609,6 +607,7 @@ async def get_agentlet[AgentOutputType]( ) -> PydanticAgent[TDeps, AgentOutputType]: """Create pydantic-ai agent from current state.""" from agentpool.agents.native_agent.tool_wrapping import wrap_tool + from agentpool.utils.context_wrapping import wrap_instruction tools = await self.tools.get_tools(state="enabled") final_type = to_type(output_type) if output_type not in [None, str] else self._output_type @@ -617,50 +616,89 @@ async def get_agentlet[AgentOutputType]( model_, _settings = self._resolve_model_string(actual_model) else: model_ = actual_model - agent = PydanticAgent( + + context_for_tools = self.get_context(input_provider=input_provider) + + # Collect pydantic_ai.tools.Tool instances using Tool.to_pydantic_ai() + pydantic_ai_tools = [] + for tool in tools: + wrapped = wrap_tool(tool, context_for_tools, hooks=self._hook_manager) + pydantic_ai_tool = tool.to_pydantic_ai(function_override=wrapped) + pydantic_ai_tools.append(pydantic_ai_tool) + + # Collect and wrap instructions from all resource providers + all_instructions: list[Any] = [] + + # Start with formatted system prompt as a static instruction + if self._formatted_system_prompt: + all_instructions.append(self._formatted_system_prompt) + + # Collect instructions from all providers + for provider in self.tools.providers: + try: + provider_instructions = await provider.get_instructions() + # Wrap each instruction for pydantic-ai compatibility + for instruction_fn in provider_instructions: + try: + wrapped_instruction = wrap_instruction(instruction_fn, fallback="") + all_instructions.append(wrapped_instruction) + except Exception: + # Wrap failure - log and skip this instruction + logger.exception( + "Failed to wrap instruction, skipping", + provider=provider.name, + instruction=instruction_fn, + ) + continue + except Exception as e: + # Provider failure - log and continue + logger.exception( + "Failed to get instructions from provider", + provider=provider.name, + error=str(e), + ) + continue + + return PydanticAgent( # type: ignore[misc] name=self.name, model=model_, model_settings=self.model_settings, - instructions=self._formatted_system_prompt, + instructions=all_instructions, retries=self._retries, end_strategy=self._end_strategy, output_retries=self._output_retries, - deps_type=self.deps_type or NoneType, - output_type=final_type, + deps_type=self.deps_type or NoneType, # type: ignore[arg-type] + output_type=final_type, # type: ignore[arg-type] + tools=pydantic_ai_tools, builtin_tools=self._builtin_tools, ) - context_for_tools = self.get_context(input_provider=input_provider) - - for tool in tools: - wrapped = wrap_tool(tool, context_for_tools, hooks=self._hook_manager) - - prepare_fn = None - if tool.schema_override: - - def create_prepare( - t: Tool, - ) -> Callable[[RunContext[Any], ToolDefinition], Awaitable[ToolDefinition | None]]: - async def prepare_schema( - ctx: RunContext[Any], tool_def: ToolDefinition - ) -> ToolDefinition | None: - if not t.schema_override: - return None - return ToolDefinition( - name=t.schema_override.get("name") or t.name, - description=t.schema_override.get("description") or t.description, - parameters_json_schema=t.schema_override.get("parameters"), - ) - - return prepare_schema + async def _process_node_stream( + self, + node_stream: AsyncIterator[Any], + *, + file_tracker: FileTracker, + pending_tcs: dict[str, BaseToolCallPart], + message_id: str, + ) -> AsyncIterator[RichAgentStreamEvent[OutputDataT]]: + """Process events from a node stream (ModelRequest or CallTools). - prepare_fn = create_prepare(tool) + Args: + node_stream: Stream of events from the node + file_tracker: Tracker for file operations + pending_tcs: Dictionary of pending tool calls + message_id: Current message ID - if get_argument_key(wrapped, RunContext): - agent.tool(prepare=prepare_fn)(wrapped) - else: - agent.tool_plain(prepare=prepare_fn)(wrapped) - return agent # type: ignore[return-value] + Yields: + Processed stream events + """ + async with merge_queue_into_iterator(node_stream, self._event_queue) as merged: + async for event in file_tracker(merged): + if self._cancelled: + break + yield event + if combined := process_tool_event(self.name, event, pending_tcs, message_id): + yield combined async def _stream_events( self, diff --git a/src/agentpool/docs/gen_examples.py b/src/agentpool/docs/gen_examples.py index 0e04bcb12..61515f262 100644 --- a/src/agentpool/docs/gen_examples.py +++ b/src/agentpool/docs/gen_examples.py @@ -12,6 +12,7 @@ DocStyle = Literal["simple", "full"] EXAMPLES_DIR = Path("src/agentpool_docs/examples") + def create_example_doc(name: str, *, style: DocStyle = "full") -> mk.MkContainer: """Create documentation for an example file. diff --git a/src/agentpool/docs/utils.py b/src/agentpool/docs/utils.py index bfb25a309..2fe5baad4 100644 --- a/src/agentpool/docs/utils.py +++ b/src/agentpool/docs/utils.py @@ -3,9 +3,9 @@ from __future__ import annotations import asyncio -import types from dataclasses import dataclass from pathlib import Path +import types from typing import TYPE_CHECKING, Annotated, Any, Self, Union, get_args, get_origin @@ -253,7 +253,7 @@ def _strip_docstring_sections(description: str) -> str: # Check if we're in a section (indented content after section header) if in_section: # If line is empty or still indented, skip it - if not stripped or line.startswith(" ") or line.startswith("\t"): + if not stripped or line.startswith((" ", "\t")): continue # Non-indented non-empty line means new content in_section = False @@ -352,8 +352,7 @@ def generate_tool_docs(toolset: ResourceProvider) -> str: lines = [f"## {toolset.name.replace('_', ' ').title()} Tools", ""] - for tool in tools: - lines.append(tool_to_markdown(tool)) + lines.extend(tool_to_markdown(tool) for tool in tools) return "\n".join(lines) diff --git a/src/agentpool/prompts/instructions.py b/src/agentpool/prompts/instructions.py new file mode 100644 index 000000000..9b8716bd6 --- /dev/null +++ b/src/agentpool/prompts/instructions.py @@ -0,0 +1,101 @@ +"""Instruction function types and protocols for dynamic prompt generation. + +This module defines the type system for instruction functions that can be used +to generate prompts dynamically based on runtime context. + +Instruction functions can be: +- Simple: No context parameters +- AgentContext: Takes only AgentContext +- RunContext: Takes only RunContext (from pydantic-ai) +- Both: Takes both AgentContext and RunContext + +The InstructionFunc union type accepts any of these variants, allowing +flexible prompt generation based on what context is available. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + + +__all__ = [ + "AgentContextInstruction", + "BothContextsInstruction", + "InstructionFunc", + "RunContextInstruction", + "SimpleInstruction", +] + + +if TYPE_CHECKING: + from collections.abc import Awaitable + + from pydantic_ai import RunContext + + from agentpool.agents.context import AgentContext + + +# Protocol definitions for type safety +@runtime_checkable +class SimpleInstruction(Protocol): + """Instruction function with no context. + + Functions matching this protocol take no parameters and return + either a string directly or an awaitable string. + """ + + def __call__(self) -> str | Awaitable[str]: ... + + +@runtime_checkable +class AgentContextInstruction(Protocol): + """Instruction function with AgentContext only. + + Functions matching this protocol receive AgentContext, which provides + access to agent-specific runtime information like the current tool, + model name, conversation history, and filesystem access. + + Useful when you need access to agent-level context but don't need + the PydanticAI run context. + """ + + def __call__(self, ctx: AgentContext[Any]) -> str | Awaitable[str]: ... + + +@runtime_checkable +class RunContextInstruction(Protocol): + """Instruction function with RunContext only. + + Functions matching this protocol receive RunContext from PydanticAI, + which provides access to dependencies and other PydanticAI-specific + runtime information. + + Useful when you need access to PydanticAI's dependency injection system + but don't need AgentPool's agent context. + """ + + def __call__(self, ctx: RunContext[Any]) -> str | Awaitable[str]: ... + + +@runtime_checkable +class BothContextsInstruction(Protocol): + """Instruction function with both AgentContext and RunContext. + + Functions matching this protocol receive both context objects, providing + maximum flexibility for prompt generation. + + Use this when you need access to both AgentPool's agent context and + PydanticAI's run context simultaneously. + """ + + def __call__( + self, + agent_ctx: AgentContext[Any], + run_ctx: RunContext[Any], + ) -> str | Awaitable[str]: ... + + +# Union type for all instruction function variants +InstructionFunc = ( + SimpleInstruction | AgentContextInstruction | RunContextInstruction | BothContextsInstruction +) diff --git a/src/agentpool/resource_providers/base.py b/src/agentpool/resource_providers/base.py index 54f81f375..bde0ef5dc 100644 --- a/src/agentpool/resource_providers/base.py +++ b/src/agentpool/resource_providers/base.py @@ -19,6 +19,7 @@ from pydantic_ai import ModelRequestPart from schemez import OpenAIFunctionDefinition + from agentpool.prompts.instructions import InstructionFunc from agentpool.prompts.prompts import BasePrompt from agentpool.resource_providers.resource_info import ResourceInfo from agentpool.skills.skill import Skill @@ -134,6 +135,10 @@ async def get_skills(self) -> list[Skill]: """Get available skills. Override to provide skills.""" return [] + async def get_instructions(self) -> list[InstructionFunc]: + """Get available instruction functions. Override to provide instructions.""" + return [] + async def get_skill_instructions(self, skill_name: str) -> str: """Get full instructions for a specific skill. diff --git a/src/agentpool/resource_providers/instruction_provider.py b/src/agentpool/resource_providers/instruction_provider.py new file mode 100644 index 000000000..f0a3cc6b4 --- /dev/null +++ b/src/agentpool/resource_providers/instruction_provider.py @@ -0,0 +1,104 @@ +"""Instruction provider wrapper for config-based dynamic instructions.""" + +from __future__ import annotations + + +__all__ = ["InstructionProvider"] + +from typing import TYPE_CHECKING, Any, Literal + +from agentpool.log import get_logger +from agentpool.resource_providers import ResourceProvider + + +if TYPE_CHECKING: + from agentpool.prompts.instructions import InstructionFunc + from agentpool_config.instructions import ProviderInstructionConfig + +logger = get_logger(__name__) + + +class InstructionProvider(ResourceProvider): + """Provider wrapper for ProviderInstructionConfig. + + This provider resolves instruction functions from either: + 1. A reference to an existing provider (ref) + 2. An import path to instantiate a provider (import_path) + + When instructions are requested, it delegates to the referenced + provider's get_instructions() method. + """ + + kind: Literal["custom"] = "custom" + + def __init__( + self, + config: ProviderInstructionConfig, + toolsets: list[ResourceProvider] | None = None, + ) -> None: + """Initialize instruction provider. + + Args: + config: The ProviderInstructionConfig to wrap + toolsets: List of existing toolsets to search for ref resolution + """ + super().__init__(name=f"instruction:{config.ref or config.import_path}") + self.config = config + self.toolsets = toolsets or [] + + async def get_tools(self) -> list[Any]: + """Return empty - this is instructions-only.""" + return [] + + async def get_instructions(self) -> list[InstructionFunc]: + """Resolve and return instruction functions. + + For ref: Find the referenced provider in toolsets and delegate. + For import_path: Instantiate the provider and delegate. + """ + from agentpool.utils.importing import import_callable + + if self.config.ref: + # Find referenced provider in toolsets by name + for provider in self.toolsets: + if provider.name == self.config.ref and isinstance(provider, ResourceProvider): + logger.info( + "Delegating to referenced provider", + ref=self.config.ref, + provider=provider.__class__.__name__, + ) + return await provider.get_instructions() + logger.warning( + "Referenced provider not found in toolsets", + ref=self.config.ref, + available_providers=[p.name for p in self.toolsets], + ) + return [] + + if self.config.import_path: + # Instantiate provider from import path + instructions: list[InstructionFunc] = [] + try: + provider_cls = import_callable(self.config.import_path) + provider_instance = provider_cls(**self.config.kw_args) + if isinstance(provider_instance, ResourceProvider): + logger.info( + "Instantiating provider from import path", + import_path=self.config.import_path, + provider=provider_cls.__name__, + ) + instructions = await provider_instance.get_instructions() + else: + logger.warning( + "Instantiated provider does not implement get_instructions", + import_path=self.config.import_path, + provider=provider_cls.__name__, + ) + except (ImportError, TypeError, AttributeError): + logger.exception( + "Failed to instantiate provider from import path", + import_path=self.config.import_path, + ) + return instructions + + return [] diff --git a/src/agentpool/tools/base.py b/src/agentpool/tools/base.py index 1fa24a18c..89d26f201 100644 --- a/src/agentpool/tools/base.py +++ b/src/agentpool/tools/base.py @@ -116,11 +116,13 @@ def get_callable(self) -> Callable[..., TOutputType | Awaitable[TOutputType]]: """Get the callable for this tool. Subclasses must implement.""" ... - def to_pydantic_ai(self) -> PydanticAiTool: + def to_pydantic_ai( + self, function_override: Callable[..., TOutputType | Awaitable[TOutputType]] | None = None + ) -> PydanticAiTool: """Convert tool to Pydantic AI tool.""" metadata = {**self.metadata, "agent_name": self.agent_name, "category": self.category} return PydanticAiTool( - function=self.get_callable(), + function=function_override if function_override is not None else self.get_callable(), name=self.name, description=self.description, requires_approval=self.requires_confirmation, diff --git a/src/agentpool/utils/context_wrapping.py b/src/agentpool/utils/context_wrapping.py new file mode 100644 index 000000000..b9481c926 --- /dev/null +++ b/src/agentpool/utils/context_wrapping.py @@ -0,0 +1,123 @@ +"""Context wrapping utilities for instruction functions. + +This module provides utilities to wrap instruction functions with appropriate +context injection for pydantic-ai compatibility. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from agentpool.log import get_logger + + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from pydantic_ai import RunContext + + from agentpool.prompts.instructions import InstructionFunc + +from agentpool.utils.inspection import ( + execute, + get_argument_key, + get_fn_name, + get_fn_qualname, +) + + +logger = get_logger(__name__) + + +def wrap_instruction( + fn: InstructionFunc, + *, + fallback: str = "", +) -> Callable[[RunContext[Any]], Awaitable[str]]: + """Wrap an instruction function for pydantic-ai compatibility. + + This utility adapts instruction functions to pydantic-ai's expected + signature: (RunContext) -> str. It automatically detects and injects + appropriate context(s) based on function signature. + + Supports four patterns: + 1. No context: () -> str + 2. AgentContext only: (AgentContext) -> str + 3. RunContext only: (RunContext) -> str + 4. Both contexts: (AgentContext, RunContext) -> str + + Args: + fn: The instruction function to wrap + fallback: Fallback string if execution fails + + Returns: + Wrapped async function: (RunContext) -> str + + Examples: + No context: + def simple() -> str: + return "Be helpful" + + wrapped = wrap_instruction(simple) + + AgentContext only: + async def with_agent(ctx: AgentContext) -> str: + return f"User: {ctx.deps.user_name}" + + wrapped = wrap_instruction(with_agent) + + Both contexts: + async def with_both(agent_ctx: AgentContext, run_ctx: RunContext) -> str: + return f"User {agent_ctx.deps.name} using {run_ctx.model.model_name}" + + wrapped = wrap_instruction(with_both) + + Accessing AgentContext from RunContext: + Note: RunContext.deps is AgentContext + + async def from_run_context(ctx: RunContext) -> str: + agent_ctx: AgentContext = ctx.deps # Access AgentContext via deps + return f"User: {agent_ctx.data.user_name}" + + wrapped = wrap_instruction(from_run_context) + """ + from pydantic_ai import RunContext + + from agentpool.agents.context import AgentContext + + # Detect which contexts function expects + agent_ctx_key = get_argument_key(fn, AgentContext) + run_ctx_key = get_argument_key(fn, RunContext) + fn_name = get_fn_name(fn) + + async def wrapper(run_ctx: RunContext[Any]) -> str: + """Wrapped function for pydantic-ai.""" + try: + kwargs: dict[str, Any] = {} + + # Inject AgentContext if expected + if agent_ctx_key: + kwargs[agent_ctx_key] = run_ctx.deps + + # Inject RunContext if expected + if run_ctx_key: + kwargs[run_ctx_key] = run_ctx + + # Execute with detected context + if kwargs: + return await execute(fn, **kwargs) + return await execute(fn) + + except Exception: + # Log error and return fallback + logger.exception( + "Instruction execution failed", + function=fn_name, + ) + return fallback + + # Preserve function metadata for debugging + wrapper.__name__ = fn_name + wrapper.__qualname__ = get_fn_qualname(fn) + + return wrapper diff --git a/src/agentpool_config/instructions.py b/src/agentpool_config/instructions.py new file mode 100644 index 000000000..eae8431fa --- /dev/null +++ b/src/agentpool_config/instructions.py @@ -0,0 +1,37 @@ +"""Configuration models for instruction providers.""" + +from __future__ import annotations + +from typing import Any, Literal + +from pydantic import BaseModel, Field, model_validator + + +class ProviderInstructionConfig(BaseModel): + """Configuration for provider-based dynamic instructions.""" + + type: Literal["provider"] = Field("provider", init=False) + + ref: str | None = Field( + default=None, + description="Name of existing toolset provider to reference.", + ) + + import_path: str | None = Field( + default=None, + description="Python import path to ResourceProvider class.", + ) + + kw_args: dict[str, Any] = Field( + default_factory=dict, + description="Keyword arguments to pass to the provider constructor.", + ) + + @model_validator(mode="after") + def validate_ref_or_import_path(self) -> ProviderInstructionConfig: + """Validate that exactly one of ref or import_path is provided.""" + if self.ref is None and self.import_path is None: + raise ValueError("Either 'ref' or 'import_path' must be provided") + if self.ref is not None and self.import_path is not None: + raise ValueError("Only one of 'ref' or 'import_path' can be provided") + return self diff --git a/tests/agents/native_agent/__init__.py b/tests/agents/native_agent/__init__.py new file mode 100644 index 000000000..c188448d9 --- /dev/null +++ b/tests/agents/native_agent/__init__.py @@ -0,0 +1 @@ +"""Test agent instruction tests package.""" diff --git a/tests/agents/native_agent/test_agent_instructions.py b/tests/agents/native_agent/test_agent_instructions.py new file mode 100644 index 000000000..cec5d8150 --- /dev/null +++ b/tests/agents/native_agent/test_agent_instructions.py @@ -0,0 +1,267 @@ +"""Test provider instruction integration into NativeAgent.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from pydantic_ai import Agent as PydanticAgent +import pytest + +from agentpool.agents.native_agent import Agent +from agentpool.resource_providers.base import ResourceProvider + + +if TYPE_CHECKING: + from agentpool.agents.context import AgentContext + from agentpool.prompts.instructions import InstructionFunc + + +class SimpleInstructionProvider(ResourceProvider): + """Simple provider that returns static instructions.""" + + def __init__(self) -> None: + super().__init__("simple_provider") + self.kind = "base" + + async def get_instructions(self) -> list[InstructionFunc]: + """Return a simple instruction function.""" + + def simple_instruction() -> str: + return "Be helpful and concise" + + return [simple_instruction] + + +class AgentContextInstructionProvider(ResourceProvider): + """Provider that returns AgentContext-aware instruction.""" + + def __init__(self) -> None: + super().__init__("agent_context_provider") + self.kind = "base" + + async def get_instructions(self) -> list[InstructionFunc]: + """Return instruction that expects AgentContext.""" + + async def with_agent_context(ctx: AgentContext[Any]) -> str: + return f"Agent model: {ctx.model_name}" + + return [with_agent_context] + + +class RunContextInstructionProvider(ResourceProvider): + """Provider that returns RunContext-aware instruction.""" + + def __init__(self) -> None: + super().__init__("run_context_provider") + self.kind = "base" + + async def get_instructions(self) -> list[InstructionFunc]: + """Return instruction that expects RunContext.""" + + async def with_run_context(ctx: Any) -> str: + return "Model: gpt-4o-mini" + + return [with_run_context] + + +class EmptyInstructionProvider(ResourceProvider): + """Provider that returns no instructions.""" + + def __init__(self) -> None: + super().__init__("empty_provider") + self.kind = "base" + + async def get_instructions(self) -> list[InstructionFunc]: + """Return empty list.""" + return [] + + +class FailingInstructionProvider(ResourceProvider): + """Provider that fails to provide instructions.""" + + def __init__(self) -> None: + super().__init__("failing_provider") + self.kind = "base" + + async def get_instructions(self) -> list[InstructionFunc]: + """Always raises RuntimeError.""" + msg = "Failed to get instructions" + raise RuntimeError(msg) + + +@pytest.fixture +async def agent_with_instruction_providers(): + """Create an agent with instruction providers.""" + from agentpool.agents.native_agent import Agent + + # Create providers + provider1 = SimpleInstructionProvider() + provider2 = AgentContextInstructionProvider() + provider3 = RunContextInstructionProvider() + + # Create agent + agent = Agent( + name="test_agent", + model="openai:gpt-4o-mini", + system_prompt="You are an AI assistant.", + ) + + # Add providers via tool manager + agent.tools.add_provider(provider1) + agent.tools.add_provider(provider2) + agent.tools.add_provider(provider3) + + return agent + + +class TestNativeAgentInstructions: + """Test NativeAgent integration with provider instructions.""" + + async def test_agentlet_collects_instructions_from_providers( + self, agent_with_instruction_providers: Agent + ): + """Test that get_agentlet collects instructions from all providers.""" + agentlet: PydanticAgent[Any, Any] = await agent_with_instruction_providers.get_agentlet( + None, None, None + ) + + # Verify that agentlet was created + assert isinstance(agentlet, PydanticAgent) + assert agentlet.name == "test_agent" + + async def test_formatted_system_prompt_includes_static_prompt( + self, agent_with_instruction_providers: Agent + ): + """Test that formatted system prompt includes static system prompt.""" + # Initialize agent to format system prompt + async with agent_with_instruction_providers: + # Access the formatted system prompt + assert agent_with_instruction_providers._formatted_system_prompt is not None + assert ( + "You are an AI assistant." + in agent_with_instruction_providers._formatted_system_prompt + ) + + async def test_instructions_are_collected_and_wrapped( + self, agent_with_instruction_providers: Agent + ): + """Test that instructions from providers are collected and wrapped.""" + # Get agentlet which should include wrapped instructions + async with agent_with_instruction_providers: + agentlet: PydanticAgent[Any, Any] = await agent_with_instruction_providers.get_agentlet( + None, None, None + ) + + # The instructions should be in agentlet's instructions + # They should be wrapped to be RunContext -> str + assert agentlet.instructions is not None + + async def test_provider_instructions_reactive(self, agent_with_instruction_providers: Agent): + """Test that provider instructions are called on each run.""" + # Create agent + async with agent_with_instruction_providers as agent: + # Run the agent (instructions should be evaluated) + # Note: This would require a model key, so we'll just test setup + agentlet: PydanticAgent[Any, Any] = await agent.get_agentlet(None, None, None) + + # Verify agentlet was created with instructions + assert agentlet is not None + assert agentlet.instructions is not None + + async def test_no_providers_uses_only_static_prompt(self): + """Test that agent works normally with no instruction providers.""" + from agentpool.agents.native_agent import Agent + + agent = Agent( + name="simple_agent", + model="openai:gpt-4o-mini", + system_prompt="You are a simple assistant.", + ) + + async with agent: + agentlet: PydanticAgent[Any, Any] = await agent.get_agentlet(None, None, None) + + # Should work with just static system prompt + assert isinstance(agentlet, PydanticAgent) + assert agentlet.instructions is not None + + async def test_provider_returning_empty_instructions(self): + """Test that providers returning empty list are handled.""" + from agentpool.agents.native_agent import Agent + + agent = Agent( + name="empty_provider_agent", + model="openai:gpt-4o-mini", + system_prompt="You are an assistant.", + ) + + agent.tools.add_provider(EmptyInstructionProvider()) + + async with agent: + # Should not fail with empty instruction list + agentlet: PydanticAgent[Any, Any] = await agent.get_agentlet(None, None, None) + assert isinstance(agentlet, PydanticAgent) + + async def test_provider_get_instructions_error_handling(self): + """Test that errors in provider.get_instructions are handled gracefully.""" + agent = Agent( + name="failing_provider_agent", + model="openai:gpt-4o-mini", + system_prompt="You are an assistant.", + ) + + agent.tools.add_provider(FailingInstructionProvider()) + + async with agent: + # Should handle error gracefully and still create agentlet + # Implementation should log error and continue + agentlet: PydanticAgent[Any, Any] = await agent.get_agentlet(None, None, None) + assert isinstance(agentlet, PydanticAgent) + + async def test_from_config_with_provider_instruction_ref(self): + """Test from_config with ProviderInstructionConfig using ref.""" + from agentpool.agents.native_agent import Agent + from agentpool.models.agents import NativeAgentConfig + from agentpool_config.instructions import ProviderInstructionConfig + + # Create a simple provider with get_instructions + class SimpleRefProvider(ResourceProvider): + def __init__(self) -> None: + super().__init__("simple_ref_provider") + self.kind = "base" + + async def get_tools(self) -> list[Any]: + return [] + + async def get_instructions(self) -> list[InstructionFunc]: + async def simple_inst() -> str: + return "Dynamic instruction from ref provider" + + return [simple_inst] + + # Create config with ProviderInstructionConfig referencing to provider + # Note: In actual usage, toolsets would come from manifest or toolset config + # For this test, we'll add provider via tool manager + + config = NativeAgentConfig( + name="test_agent_with_ref", + model="openai:gpt-4o-mini", + system_prompt=["Be helpful.", ProviderInstructionConfig(ref="simple_ref_provider")], + ) + + # Create agent from config + agent = Agent.from_config(config) + + # Manually add the referenced provider to the tool manager + # This simulates how it would come from toolsets in real usage + provider = SimpleRefProvider() + agent.tools.add_provider(provider) + + async with agent: + # Verify agentlet can be created + agentlet: PydanticAgent[Any, Any] = await agent.get_agentlet(None, None, None) + assert isinstance(agentlet, PydanticAgent) + + # Verify that a provider is in the tools.providers list + provider_names = [p.name for p in agent.tools.providers] + assert "instruction:simple_ref_provider" in provider_names diff --git a/tests/config/__init__.py b/tests/config/__init__.py new file mode 100644 index 000000000..1eb2ab995 --- /dev/null +++ b/tests/config/__init__.py @@ -0,0 +1 @@ +"""Tests for config models.""" diff --git a/tests/config/test_instructions_config.py b/tests/config/test_instructions_config.py new file mode 100644 index 000000000..ab9e3a30e --- /dev/null +++ b/tests/config/test_instructions_config.py @@ -0,0 +1,66 @@ +"""Test ProviderInstructionConfig model.""" + +from __future__ import annotations + +import pytest + + +class TestProviderInstructionConfig: + """Test ProviderInstructionConfig model.""" + + async def test_valid_config_with_ref(self): + """Test valid config with ref field only.""" + from agentpool_config.instructions import ProviderInstructionConfig + + config = ProviderInstructionConfig( + ref="my_provider", + ) + + assert config.type == "provider" + assert config.ref == "my_provider" + assert config.import_path is None + assert config.kw_args == {} + + async def test_valid_config_with_import_path(self): + """Test valid config with import_path field only.""" + from agentpool_config.instructions import ProviderInstructionConfig + + config = ProviderInstructionConfig( + import_path="my.module.Provider", + ) + + assert config.type == "provider" + assert config.import_path == "my.module.Provider" + assert config.ref is None + assert config.kw_args == {} + + async def test_valid_config_with_import_path_and_kwargs(self): + """Test valid config with import_path and kwargs.""" + from agentpool_config.instructions import ProviderInstructionConfig + + config = ProviderInstructionConfig( + import_path="my.module.Provider", + kw_args={"key": "value"}, + ) + + assert config.type == "provider" + assert config.import_path == "my.module.Provider" + assert config.ref is None + assert config.kw_args == {"key": "value"} + + async def test_invalid_config_both_ref_and_import_path(self): + """Test that config with both ref and import_path raises error.""" + from agentpool_config.instructions import ProviderInstructionConfig + + with pytest.raises(ValueError, match="Only one of 'ref' or 'import_path' can be provided"): + ProviderInstructionConfig( + ref="my_provider", + import_path="my.module.Provider", + ) + + async def test_invalid_config_neither_ref_nor_import_path(self): + """Test that config with neither ref nor import_path raises error.""" + from agentpool_config.instructions import ProviderInstructionConfig + + with pytest.raises(ValueError, match="Either 'ref' or 'import_path' must be provided"): + ProviderInstructionConfig() diff --git a/tests/prompts/test_instructions_types.py b/tests/prompts/test_instructions_types.py new file mode 100644 index 000000000..113a1be58 --- /dev/null +++ b/tests/prompts/test_instructions_types.py @@ -0,0 +1,213 @@ +"""Test instruction function types and protocols.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable # noqa: TC003 +from typing import TYPE_CHECKING, Any + +import pytest + +# Test that all types can be imported (will fail if they don't exist) +from agentpool.prompts.instructions import ( # noqa: TC001 + AgentContextInstruction, + BothContextsInstruction, + InstructionFunc, + RunContextInstruction, + SimpleInstruction, +) + + +if TYPE_CHECKING: + from pydantic_ai import RunContext + + from agentpool.agents.context import AgentContext + + +class TestSimpleInstruction: + """Test SimpleInstruction protocol.""" + + def test_sync_simple_instruction(self): + """Simple synchronous instruction function.""" + + def simple_prompt() -> str: + return "Hello, world!" + + # The function should match the SimpleInstruction protocol + result: str | Awaitable[str] = simple_prompt() + assert result == "Hello, world!" + + async def test_async_simple_instruction(self): + """Simple async instruction function.""" + + async def simple_prompt() -> str: + await asyncio.sleep(0) + return "Hello, async world!" + + # The function should match the SimpleInstruction protocol + async def use_instruction(instruction: SimpleInstruction) -> str: + result = instruction() + if isinstance(result, str): + return result + return await result + + result = await use_instruction(simple_prompt) + assert result == "Hello, async world!" + + +class TestAgentContextInstruction: + """Test AgentContextInstruction protocol.""" + + def test_sync_agent_context_instruction(self): + """Instruction function with AgentContext only.""" + + def context_prompt(ctx: AgentContext[Any]) -> str: + return f"Context: {ctx.tool_name or 'none'}" + + # Create a mock context (we'll use None for this test) + # The function should match the AgentContextInstruction protocol + def use_instruction(instruction: AgentContextInstruction) -> str: + + # Create minimal agent for testing + with pytest.MonkeyPatch().context(): + pass # We can't easily create a full AgentContext without setup + return "placeholder" + + async def test_async_agent_context_instruction(self): + """Async instruction function with AgentContext only.""" + + async def context_prompt(ctx: AgentContext[Any]) -> str: + await asyncio.sleep(0) + return f"Context: {ctx.tool_name or 'none'}" + + # The function should match the AgentContextInstruction protocol + async def use_instruction(instruction: AgentContextInstruction) -> str: + + # Create minimal agent for testing + return "placeholder" + + +class TestRunContextInstruction: + """Test RunContextInstruction protocol.""" + + def test_sync_run_context_instruction(self): + """Instruction function with RunContext only.""" + + def context_prompt(ctx: RunContext[Any]) -> str: + return f"Dep: {ctx.deps or 'none'}" + + # The function should match the RunContextInstruction protocol + def use_instruction(instruction: RunContextInstruction) -> str: + return "placeholder" + + +class TestBothContextsInstruction: + """Test BothContextsInstruction protocol.""" + + def test_sync_both_contexts_instruction(self): + """Instruction function with both AgentContext and RunContext.""" + + def dual_prompt( + agent_ctx: AgentContext[Any], + run_ctx: RunContext[Any], + ) -> str: + return f"Agent: {agent_ctx.tool_name or 'none'}, Run: {run_ctx.deps or 'none'}" + + # The function should match the BothContextsInstruction protocol + def use_instruction(instruction: BothContextsInstruction) -> str: + return "placeholder" + + +class TestInstructionFuncUnion: + """Test InstructionFunc union type.""" + + def test_simple_instruction_in_union(self): + """SimpleInstruction should be assignable to InstructionFunc.""" + + def simple() -> str: + return "test" + + func: InstructionFunc = simple + # Just verify type compatibility + assert callable(func) + + def test_agent_context_instruction_in_union(self): + """AgentContextInstruction should be assignable to InstructionFunc.""" + + def with_agent_context(ctx: AgentContext[Any]) -> str: + return "test" + + func: InstructionFunc = with_agent_context + # Just verify type compatibility + assert callable(func) + + def test_run_context_instruction_in_union(self): + """RunContextInstruction should be assignable to InstructionFunc.""" + + def with_run_context(ctx: RunContext[Any]) -> str: + return "test" + + func: InstructionFunc = with_run_context + # Just verify type compatibility + assert callable(func) + + def test_both_contexts_instruction_in_union(self): + """BothContextsInstruction should be assignable to InstructionFunc.""" + + def with_both_contexts( + agent_ctx: AgentContext[Any], + run_ctx: RunContext[Any], + ) -> str: + return "test" + + func: InstructionFunc = with_both_contexts + # Just verify type compatibility + assert callable(func) + + +class TestRuntimeCheckable: + """Test @runtime_checkable decorator on protocols.""" + + def test_simple_instruction_isinstance(self): + """SimpleInstruction should support isinstance checks.""" + + def simple() -> str: + return "test" + + # @runtime_checkable enables isinstance checks + assert isinstance(simple, SimpleInstruction) + + def test_agent_context_instruction_isinstance(self): + """AgentContextInstruction should support isinstance checks.""" + + def with_ctx(ctx: AgentContext[Any]) -> str: + return "test" + + assert isinstance(with_ctx, AgentContextInstruction) + + def test_run_context_instruction_isinstance(self): + """RunContextInstruction should support isinstance checks.""" + + def with_ctx(ctx: RunContext[Any]) -> str: + return "test" + + assert isinstance(with_ctx, RunContextInstruction) + + def test_both_contexts_instruction_isinstance(self): + """BothContextsInstruction should support isinstance checks.""" + + def with_both( + agent_ctx: AgentContext[Any], + run_ctx: RunContext[Any], + ) -> str: + return "test" + + assert isinstance(with_both, BothContextsInstruction) + + def test_async_simple_instruction_isinstance(self): + """Async simple instruction should support isinstance checks.""" + + async def async_simple() -> str: + return "test" + + assert isinstance(async_simple, SimpleInstruction) diff --git a/tests/resource_providers/test_base.py b/tests/resource_providers/test_base.py new file mode 100644 index 000000000..6a45857c2 --- /dev/null +++ b/tests/resource_providers/test_base.py @@ -0,0 +1,49 @@ +"""Test ResourceProvider base class.""" + +from __future__ import annotations + + +class TestResourceProviderGetInstructions: + """Test get_instructions method on ResourceProvider base class.""" + + async def test_get_instructions_method_exists(self): + """Test that get_instructions method exists on ResourceProvider.""" + from agentpool.resource_providers.base import ResourceProvider + + assert hasattr(ResourceProvider, "get_instructions") + assert callable(ResourceProvider.get_instructions) + + async def test_get_instructions_is_async(self): + """Test that get_instructions is an async method.""" + import inspect + + from agentpool.resource_providers.base import ResourceProvider + + assert inspect.iscoroutinefunction(ResourceProvider.get_instructions) + + async def test_get_instructions_returns_list(self): + """Test that default get_instructions returns an empty list.""" + from agentpool.resource_providers.base import ResourceProvider + + class TestProvider(ResourceProvider): + pass + + provider = TestProvider(name="test_provider") + result = await provider.get_instructions() + + assert isinstance(result, list) + assert result == [] + + async def test_get_instructions_signature(self): + """Test that get_instructions has correct return type annotation.""" + import inspect + + from agentpool.resource_providers.base import ResourceProvider + + sig = inspect.signature(ResourceProvider.get_instructions) + + # Should be annotated as returning list[InstructionFunc] + # Note: In Python 3.9+, we can use list[...] + # Type checking may show this differently, but runtime check should work + # For now, just verify it has a return annotation + assert sig.return_annotation is not inspect.Signature.empty diff --git a/tests/utils/test_context_wrapping.py b/tests/utils/test_context_wrapping.py new file mode 100644 index 000000000..759a049ef --- /dev/null +++ b/tests/utils/test_context_wrapping.py @@ -0,0 +1,190 @@ +"""Test context wrapping utility for instruction functions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, Mock + + +if TYPE_CHECKING: + from collections.abc import Awaitable + +from pydantic_ai import RunContext + +from agentpool.agents.context import AgentContext + + +class TestWrapInstruction: + """Test wrap_instruction utility.""" + + def _create_mock_run_context(self, deps: Any = None) -> RunContext[Any]: + """Helper to create a mock RunContext with all required params.""" + mock_model = Mock() + mock_usage = Mock() + mock_messages: list[Any] = [] + return RunContext( + deps=deps, + model=mock_model, + usage=mock_usage, + messages=mock_messages, + ) + + async def test_sync_function_no_context(self): + """Test wrapping a sync function with no context.""" + + def simple() -> str: + return "Be helpful" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(simple) + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "Be helpful" + + async def test_async_function_no_context(self): + """Test wrapping an async function with no context.""" + + async def simple_async() -> str: + return "Be helpful and async" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(simple_async) + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "Be helpful and async" + + async def test_function_with_agent_context(self): + """Test wrapping a function that expects AgentContext.""" + + def with_agent(ctx: AgentContext) -> str: + return f"Tool: {ctx.tool_name or 'none'}" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(with_agent) + + # Create a mock AgentContext that only needs tool_name attribute + mock_agent_ctx = MagicMock(spec=AgentContext) + mock_agent_ctx.tool_name = None + + run_ctx = self._create_mock_run_context(deps=mock_agent_ctx) + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "Tool: none" + + async def test_async_function_with_agent_context(self): + """Test wrapping an async function that expects AgentContext.""" + + async def with_agent_async(ctx: AgentContext) -> str: + return f"Tool (async): {ctx.tool_name or 'none'}" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(with_agent_async) + + # Create a mock AgentContext + mock_agent_ctx = MagicMock(spec=AgentContext) + mock_agent_ctx.tool_name = None + + run_ctx = self._create_mock_run_context(deps=mock_agent_ctx) + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "Tool (async): none" + + async def test_function_with_run_context(self): + """Test wrapping a function that expects RunContext.""" + + def with_run(ctx: RunContext) -> str: + return f"RunContext: {ctx.deps or 'none'}" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(with_run) + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "RunContext: none" + + async def test_function_with_both_contexts(self): + """Test wrapping a function that expects both AgentContext and RunContext.""" + + def with_both(agent_ctx: AgentContext, run_ctx: RunContext) -> str: + return f"Agent: {agent_ctx.tool_name or 'none'}, Run: {run_ctx.deps or 'none'}" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(with_both) + + # Create a mock AgentContext + mock_agent_ctx = MagicMock(spec=AgentContext) + mock_agent_ctx.tool_name = None + + run_ctx = self._create_mock_run_context(deps=mock_agent_ctx) + result_str = await wrapped(run_ctx) + # The result will contain to actual object representation + assert "Agent: none" in result_str + assert "Run:" in result_str + assert "MagicMock" in result_str + + async def test_error_handling_with_fallback(self): + """Test that errors are caught and fallback is returned.""" + + def error_func() -> str: + raise ValueError("This should be caught") + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(error_func, fallback="Fallback text") + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "Fallback text" + + async def test_default_fallback_empty_string(self): + """Test that default fallback is empty string.""" + + def error_func() -> str: + raise ValueError("Error") + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(error_func) # No fallback specified + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + assert await result == "" + + async def test_wrapped_preserves_function_name(self): + """Test that wrapper preserves original function name.""" + + def named_function() -> str: + return "test" + + from agentpool.utils.context_wrapping import wrap_instruction + + wrapped = wrap_instruction(named_function) + + # functools.wraps preserves __name__ + assert wrapped.__name__ == "named_function" + + async def test_instruction_func_union(self): + """Test that InstructionFunc union types are accepted.""" + + def simple() -> str: + return "simple" + + def with_agent(ctx: AgentContext) -> str: + return "agent" + + def with_run(ctx: RunContext) -> str: + return "run" + + def with_both(agent_ctx: AgentContext, run_ctx: RunContext) -> str: + return "both" + + from agentpool.utils.context_wrapping import wrap_instruction + + # All should be accepted as InstructionFunc + for func in [simple, with_agent, with_run, with_both]: + wrapped = wrap_instruction(func) # type: ignore[arg-type] + run_ctx = self._create_mock_run_context() + result: Awaitable[str] = wrapped(run_ctx) + # Just verify it doesn't crash + await result