diff --git a/cognite/client/_api/agents/agents.py b/cognite/client/_api/agents/agents.py index bd177be879..f6eedb9f22 100644 --- a/cognite/client/_api/agents/agents.py +++ b/cognite/client/_api/agents/agents.py @@ -1,11 +1,16 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload from cognite.client._api_client import APIClient from cognite.client.data_classes.agents import Agent, AgentList, AgentUpsert -from cognite.client.data_classes.agents.chat import AgentChatResponse, Message, MessageList +from cognite.client.data_classes.agents.chat import ( + Action, + ActionResult, + AgentChatResponse, + Message, +) from cognite.client.utils._experimental import FeaturePreviewWarning from cognite.client.utils._identifier import IdentifierSequence from cognite.client.utils.useful_types import SequenceNotStr @@ -243,8 +248,9 @@ def list(self) -> AgentList: # The API does not yet support limit or pagination def chat( self, agent_id: str, - messages: Message | Sequence[Message], + messages: Message | ActionResult | Sequence[Message | ActionResult], cursor: str | None = None, + actions: Sequence[Action] | None = None, ) -> AgentChatResponse: """`Chat with an agent. `_ @@ -253,9 +259,10 @@ def chat( Args: agent_id (str): External ID that uniquely identifies the agent. - messages (Message | Sequence[Message]): A list of one or many input messages to the agent. + messages (Message | ActionResult | Sequence[Message | ActionResult]): A list of one or many input messages to the agent. Can include regular messages and action results. cursor (str | None): The cursor to use for continuation of a conversation. Use this to create multi-turn conversations, as the cursor will keep track of the conversation state. + actions (Sequence[Action] | None): A list of client-side actions that can be called by the agent. Returns: AgentChatResponse: The response from the agent. 
@@ -290,22 +297,60 @@ def chat( ... Message("Once you have found it, find related time series.") ... ] ... ) + + Chat with client-side actions: + + >>> from cognite.client.data_classes.agents import ClientToolAction, ClientToolResult + >>> add_numbers_action = ClientToolAction( + ... name="add", + ... description="Add two numbers together", + ... parameters={ + ... "type": "object", + ... "properties": { + ... "a": {"type": "number", "description": "First number"}, + ... "b": {"type": "number", "description": "Second number"}, + ... }, + ... "required": ["a", "b"] + ... } + ... ) + >>> response = client.agents.chat( + ... agent_id="my_agent", + ... messages=Message("What is 42 plus 58?"), + ... actions=[add_numbers_action] + ... ) + >>> if response.action_calls: + ... for call in response.action_calls: + ... # Execute the action + ... result = call.arguments["a"] + call.arguments["b"] + ... # Send result back + ... response = client.agents.chat( + ... agent_id="my_agent", + ... messages=ClientToolResult( + ... action_id=call.action_id, + ... content=f"The result is {result}" + ... ), + ... cursor=response.cursor, + ... actions=[add_numbers_action] + ... 
) """ self._warnings.warn() # Convert single message to list - if isinstance(messages, Message): + if isinstance(messages, (Message, ActionResult)): messages = [messages] # Build request body - body = { + body: dict[str, Any] = { "agentId": agent_id, - "messages": MessageList(messages).dump(camel_case=True), + "messages": [msg.dump(camel_case=True) for msg in messages], } if cursor is not None: body["cursor"] = cursor + if actions is not None: + body["actions"] = [action.dump(camel_case=True) for action in actions] + # Make the API call response = self._post( url_path=self._RESOURCE_PATH + "/chat", diff --git a/cognite/client/data_classes/agents/__init__.py b/cognite/client/data_classes/agents/__init__.py index a9931b6240..2ee4644ecf 100644 --- a/cognite/client/data_classes/agents/__init__.py +++ b/cognite/client/data_classes/agents/__init__.py @@ -21,19 +21,31 @@ ) from cognite.client.data_classes.agents.agents import Agent, AgentList, AgentUpsert, AgentUpsertList from cognite.client.data_classes.agents.chat import ( + Action, + ActionCall, + ActionResult, AgentChatResponse, AgentDataItem, AgentMessage, AgentMessageList, AgentReasoningItem, + ClientToolAction, + ClientToolCall, + ClientToolResult, Message, MessageContent, MessageList, TextContent, + UnknownAction, + UnknownActionCall, + UnknownActionResult, UnknownContent, ) __all__ = [ + "Action", + "ActionCall", + "ActionResult", "Agent", "AgentChatResponse", "AgentDataItem", @@ -49,6 +61,9 @@ "AgentUpsertList", "AskDocumentAgentTool", "AskDocumentAgentToolUpsert", + "ClientToolAction", + "ClientToolCall", + "ClientToolResult", "DataModelInfo", "InstanceSpaces", "Message", @@ -62,6 +77,9 @@ "SummarizeDocumentAgentTool", "SummarizeDocumentAgentToolUpsert", "TextContent", + "UnknownAction", + "UnknownActionCall", + "UnknownActionResult", "UnknownAgentTool", "UnknownAgentToolUpsert", "UnknownContent", diff --git a/cognite/client/data_classes/agents/chat.py b/cognite/client/data_classes/agents/chat.py index 
@dataclass
class Action(CogniteObject, ABC):
    """Base class for all action types that can be provided to an agent.

    Concrete subclasses declare a ``_type`` discriminator. ``_load`` reads the ``type``
    field of the raw payload, looks the class up in ``_ACTION_CLS_BY_TYPE`` and falls
    back to ``UnknownAction`` when the type is not registered.
    """

    _type: ClassVar[str]

    @classmethod
    def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> Action:
        """Dispatch to the correct concrete action class based on `type`."""
        action_class = _ACTION_CLS_BY_TYPE.get(data.get("type", ""), UnknownAction)
        return action_class._load_action(data, cognite_client)

    @abstractmethod
    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the action to a dictionary."""
        ...

    @classmethod
    @abstractmethod
    def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> Action:
        """Create a concrete action instance from raw data."""
        ...


@dataclass
class ClientToolAction(Action):
    """A client-side tool definition that can be called by the agent.

    Args:
        name (str): The name of the client tool to call.
        description (str): A description of what the function does. The language model will use this description when selecting the function and interpreting its parameters.
        parameters (dict[str, Any]): The parameters the function accepts, described as a JSON Schema object.
    """

    _type: ClassVar[str] = "clientTool"
    name: str
    description: str
    parameters: dict[str, Any]

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the tool definition as ``{"type": ..., "clientTool": {...}}``.

        Args:
            camel_case (bool): Accepted for interface compatibility; the emitted keys are the same either way.

        Returns:
            dict[str, Any]: The serialised tool definition.
        """
        return {
            "type": self._type,
            "clientTool": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

    @classmethod
    def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolAction:
        """Build the action from the nested ``clientTool`` object of the payload."""
        client_tool = data["clientTool"]
        return cls(
            name=client_tool["name"],
            description=client_tool["description"],
            parameters=client_tool["parameters"],
        )

    @classmethod
    def from_function(cls, func: Callable[..., Any], name: str | None = None) -> ClientToolAction:
        """Create a ClientToolAction from a Python function with type hints and docstring.

        This method introspects a Python function to automatically generate an action definition
        that can be used with agents. The function must have type hints for all parameters.

        **Supported Types:**
        - Primitives: ``int``, ``float``, ``str``, ``bool``
        - Lists of primitives: ``list[int]``, ``list[float]``, ``list[str]``, ``list[bool]``

        **Optional Parameters:**
        Parameters with a default value of ``None`` are treated as optional and will not be
        included in the ``required`` list of the JSON Schema.

        **Docstring Format:**
        The function should have a Google-style docstring with:
        - A short description (used as the action description)
        - An ``Args:`` section with parameter descriptions

        Args:
            func (Callable[..., Any]): The Python function to convert to an action.
            name (str | None): Optional custom name for the action. If not provided, uses the function name.

        Returns:
            ClientToolAction: An action that can be used in agent chat.

        Raises:
            TypeError: If a parameter is missing a type annotation or has an unsupported type.

        Examples:
            Basic usage with primitives::

                def add(a: float, b: float) -> float:
                    \"\"\"Add two numbers.

                    Args:
                        a: The first number.
                        b: The second number.

                    Returns:
                        The sum of the two numbers.
                    \"\"\"
                    return a + b

                action = ClientToolAction.from_function(add)

            With optional parameters::

                def greet(name: str, title: str = None) -> str:
                    \"\"\"Greet a person.

                    Args:
                        name: The person's name.
                        title: Optional title (e.g., "Dr.", "Mr.").

                    Returns:
                        A greeting message.
                    \"\"\"
                    if title:
                        return f"Hello, {title} {name}!"
                    return f"Hello, {name}!"

                action = ClientToolAction.from_function(greet)
                # name is required, title is optional

            With list parameters::

                def sum_numbers(numbers: list[float]) -> float:
                    \"\"\"Calculate the sum of a list of numbers.

                    Args:
                        numbers: List of numbers to sum.

                    Returns:
                        The sum of all numbers.
                    \"\"\"
                    return sum(numbers)

                action = ClientToolAction.from_function(sum_numbers)

            With custom name::

                action = ClientToolAction.from_function(add, name="custom_add")

            Use in agent chat::

                response = client.agents.chat(
                    agent_id="my_agent",
                    messages=Message("What is 42 plus 58?"),
                    actions=[ClientToolAction.from_function(add)]
                )
        """
        action_name = name or func.__name__

        # The helper falls back to the function name when it finds no description.
        func_description, parameters_schema = function_to_json_schema(func)

        # Look at the docstring itself rather than comparing the description with the
        # function name: a genuine docstring whose summary happens to equal the name
        # must not be reported as missing.
        if not (func.__doc__ or "").strip():
            warnings.warn(
                f"Function '{func.__name__}' has no docstring, using function name as description",
                UserWarning,
                stacklevel=2,
            )

        return cls(
            name=action_name,
            description=func_description,
            parameters=parameters_schema,
        )


@dataclass
class UnknownAction(Action):
    """Unknown action type for forward compatibility.

    Args:
        type (str): The action type.
        data (dict[str, Any]): The raw action data.
    """

    type: str
    data: dict[str, Any] = field(default_factory=dict)

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Return a shallow copy of the raw payload with ``type`` set from this instance."""
        result = self.data.copy()
        result["type"] = self.type
        return result

    @classmethod
    def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownAction:
        """Keep the whole raw payload so that it can be dumped back unchanged."""
        return cls(data=data, type=data.get("type", ""))


# Build the mapping AFTER concrete classes are defined.
# Only direct subclasses that assign a `_type` value and are not abstract are registered,
# which leaves UnknownAction (no `_type` value) as the fallback only.
_ACTION_CLS_BY_TYPE: dict[str, type[Action]] = {
    subclass._type: subclass  # type: ignore[type-abstract]
    for subclass in Action.__subclasses__()
    if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None)
}


@dataclass
class ActionCall(CogniteObject, ABC):
    """Base class for action calls requested by the agent.

    ``_load`` dispatches on the ``type`` field through ``_ACTION_CALL_CLS_BY_TYPE`` and
    falls back to ``UnknownActionCall`` for unregistered types.
    """

    _type: ClassVar[str]
    action_id: str

    @classmethod
    def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionCall:
        """Dispatch to the correct concrete action call class based on `type`."""
        action_class = _ACTION_CALL_CLS_BY_TYPE.get(data.get("type", ""), UnknownActionCall)
        return action_class._load_call(data, cognite_client)

    @abstractmethod
    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the action call to a dictionary."""
        ...

    @classmethod
    @abstractmethod
    def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionCall:
        """Create a concrete action call instance from raw data."""
        ...


@dataclass
class ClientToolCall(ActionCall):
    """A client tool call requested by the agent.

    Args:
        action_id (str): The unique identifier for this action call.
        name (str): The name of the client tool being called.
        arguments (dict[str, Any]): The parsed arguments for the tool call.
    """

    _type: ClassVar[str] = "clientTool"
    action_id: str
    name: str
    arguments: dict[str, Any]

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the call; ``arguments`` is serialised back to a JSON string."""
        return {
            "type": self._type,
            "actionId" if camel_case else "action_id": self.action_id,
            "clientTool": {
                "name": self.name,
                "arguments": json.dumps(self.arguments),
            },
        }

    @classmethod
    def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolCall:
        """Build the call from the payload, decoding the JSON-encoded ``arguments``.

        Raises:
            KeyError: If ``clientTool``, ``name`` or ``actionId`` is missing.
            json.JSONDecodeError: If ``arguments`` is a non-empty string that is not valid JSON.
        """
        client_tool = data["clientTool"]
        raw_arguments = client_tool.get("arguments")
        if isinstance(raw_arguments, str):
            # An empty string is not valid JSON; treat it as "no arguments" instead of crashing.
            # NOTE(review): whether the API ever sends an empty string is not visible here - confirm.
            arguments = json.loads(raw_arguments) if raw_arguments.strip() else {}
        else:
            # Missing key, or a value that is already decoded.
            arguments = raw_arguments or {}
        return cls(
            action_id=data["actionId"],
            name=client_tool["name"],
            arguments=arguments,
        )


@dataclass
class UnknownActionCall(ActionCall):
    """Unknown action call type for forward compatibility.

    Args:
        action_id (str): The unique identifier for this action call.
        type (str): The action call type.
        data (dict[str, Any]): The raw action call data.
    """

    action_id: str
    type: str
    data: dict[str, Any] = field(default_factory=dict)

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Return a copy of the raw payload with ``type`` and the action id set from this instance."""
        result = self.data.copy()
        result["type"] = self.type
        if not camel_case:
            # The raw payload carries the id as "actionId"; drop it so the snake_case dump
            # does not contain the same id under two different keys.
            result.pop("actionId", None)
        result["actionId" if camel_case else "action_id"] = self.action_id
        return result

    @classmethod
    def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownActionCall:
        """Keep the whole raw payload; missing ``type``/``actionId`` default to empty strings."""
        return cls(action_id=data.get("actionId", ""), data=data, type=data.get("type", ""))


# Build the mapping AFTER concrete classes are defined (same registration rule as for actions).
_ACTION_CALL_CLS_BY_TYPE: dict[str, type[ActionCall]] = {
    subclass._type: subclass  # type: ignore[type-abstract]
    for subclass in ActionCall.__subclasses__()
    if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None)
}
@dataclass
class ActionResult(CogniteResource, ABC):
    """Base class for action execution results.

    ``_load`` dispatches on the ``type`` field through ``_ACTION_RESULT_CLS_BY_TYPE`` and
    falls back to ``UnknownActionResult`` for unregistered types.
    """

    _type: ClassVar[str]
    action_id: str
    role: Literal["action"] = "action"

    @classmethod
    def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult:
        """Dispatch to the correct concrete action result class based on `type`."""
        action_class = _ACTION_RESULT_CLS_BY_TYPE.get(data.get("type", ""), UnknownActionResult)
        return action_class._load_result(data, cognite_client)

    @abstractmethod
    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the action result to a dictionary."""
        ...

    @classmethod
    @abstractmethod
    def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult:
        """Create a concrete action result instance from raw data."""
        ...


@dataclass
class ClientToolResult(ActionResult):
    """Result of executing a client tool, for sending back to the agent.

    Args:
        action_id (str): The ID of the action being responded to.
        content (str | MessageContent): The result of executing the action.
        data (list[Any] | None): Optional structured data.
    """

    _type: ClassVar[str] = "clientTool"
    content: MessageContent = field(init=False)
    data: list[Any] | None = field(init=False, default=None)

    def __init__(
        self,
        action_id: str,
        content: str | MessageContent,
        data: list[Any] | None = None,
    ) -> None:
        self.action_id = action_id
        # A plain string is wrapped so that `content` is always a MessageContent.
        self.content = TextContent(text=content) if isinstance(content, str) else content
        self.data = data
        self.role = "action"

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Dump the result; ``data`` is emitted as an empty list when unset or empty."""
        return {
            "role": self.role,
            "type": self._type,
            "actionId" if camel_case else "action_id": self.action_id,
            "content": self.content.dump(camel_case=camel_case),
            "data": self.data or [],
        }

    @classmethod
    def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolResult:
        """Build the result from a payload carrying ``actionId``, ``content`` and optionally ``data``."""
        return cls(
            action_id=data["actionId"],
            content=MessageContent._load(data["content"]),
            data=data.get("data"),
        )


@dataclass
class UnknownActionResult(ActionResult):
    """Unknown action result type for forward compatibility.

    Args:
        action_id (str): The ID of the action being responded to.
        type (str): The action result type.
        data (dict[str, Any]): The raw action result data.
    """

    type: str = ""
    data: dict[str, Any] = field(default_factory=dict)

    def dump(self, camel_case: bool = True) -> dict[str, Any]:
        """Return a copy of the raw payload with role, type and action id set from this instance."""
        result = self.data.copy()
        result["role"] = self.role
        result["type"] = self.type
        if not camel_case:
            # The raw payload carries the id as "actionId"; drop it so the snake_case dump
            # does not contain the same id under two different keys.
            result.pop("actionId", None)
        result["actionId" if camel_case else "action_id"] = self.action_id
        return result

    @classmethod
    def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownActionResult:
        """Keep the whole raw payload; missing ``type``/``actionId`` default to empty strings."""
        return cls(action_id=data.get("actionId", ""), data=data, type=data.get("type", ""))


# Build the mapping AFTER concrete classes are defined.
# Only direct subclasses that assign a `_type` value and are not abstract are registered.
_ACTION_RESULT_CLS_BY_TYPE: dict[str, type[ActionResult]] = {
    subclass._type: subclass  # type: ignore[type-abstract]
    for subclass in ActionResult.__subclasses__()
    if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None)
}
@property
def action_calls(self) -> list[ActionCall] | None:
    """Collect the action calls requested across every message of the response.

    Returns:
        list[ActionCall] | None: The calls in message order, or None when there are no
            messages or no message carries any action.
    """
    if not self.messages:
        return None
    collected = [call for message in self.messages if message.actions for call in message.actions]
    return collected or None
new file mode 100644 index 0000000000..ebbc2ad959 --- /dev/null +++ b/cognite/client/utils/_function_introspection.py @@ -0,0 +1,279 @@ +"""Utilities for introspecting Python functions to generate JSON Schema. + +This module provides utilities for converting Python function signatures and docstrings +into JSON Schema definitions, primarily for use with agent function calling features. +""" + +from __future__ import annotations + +import inspect +import re +from collections.abc import Callable +from typing import Any, get_args, get_origin, get_type_hints + + +def parse_google_docstring(docstring: str | None) -> tuple[str, dict[str, str]]: + """Parse Google-style docstring to extract description and parameter descriptions. + + Args: + docstring (str | None): The function's docstring. + + Returns: + tuple[str, dict[str, str]]: Tuple of (function_description, param_descriptions_dict). + + Examples: + >>> def example(param1: str, param2: int) -> None: + ... '''Example function. + ... + ... Args: + ... param1: First parameter. + ... param2: Second parameter. + ... ''' + ... pass + >>> desc, params = parse_google_docstring(example.__doc__) + >>> desc + 'Example function.' + >>> params['param1'] + 'First parameter.' + """ + if not docstring: + return "", {} + + # Split into lines + lines = docstring.strip().split("\n") + + # Extract function description (everything before Args section) + description_lines = [] + param_descriptions = {} + in_args_section = False + current_param = None + current_param_desc_lines = [] + + for line in lines: + stripped = line.strip() + + # Check if we're entering Args section + if stripped in ("Args:", "Arguments:", "Parameters:"): + in_args_section = True + continue + + # Check if we're leaving Args section (Returns, Raises, etc.) 
+ if in_args_section and stripped and stripped.endswith(":"): + # Save any pending parameter description + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + break + + # If we're not in args section yet, check if we hit Returns/Raises (end of description) + if ( + not in_args_section + and stripped + and stripped.endswith(":") + and stripped + in ( + "Returns:", + "Return:", + "Raises:", + "Raises:", + "Yields:", + "Yield:", + "Note:", + "Notes:", + "Example:", + "Examples:", + ) + ): + break + + if in_args_section: + # Parse parameter line: "param_name: description" or "param_name (type): description" + param_match = re.match(r"^\s*(\w+)(?:\s*\([^)]+\))?\s*:\s*(.*)$", line) + if param_match: + # Save previous parameter if any + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + + # Start new parameter + current_param = param_match.group(1) + current_param_desc_lines = [param_match.group(2)] if param_match.group(2) else [] + elif current_param and stripped: + # Continuation of current parameter description + current_param_desc_lines.append(stripped) + else: + # Part of function description + if stripped: + description_lines.append(stripped) + + # Save last parameter + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + + # Join description lines + description = " ".join(description_lines).strip() + + return description, param_descriptions + + +def type_to_json_schema(param_type: type, param_name: str) -> dict[str, Any]: + """Convert Python type hint to JSON Schema type. + + Supports primitive types (int, float, str, bool) and lists of primitives. + + Args: + param_type (type): The type annotation. + param_name (str): The parameter name (for error messages). + + Returns: + dict[str, Any]: JSON Schema type object. 
+ + Raises: + TypeError: If the type is not supported. + + Examples: + >>> type_to_json_schema(int, "count") + {'type': 'integer'} + >>> type_to_json_schema(list[str], "names") + {'type': 'array', 'items': {'type': 'string'}} + """ + # Handle primitives + if param_type is int: + return {"type": "integer"} + elif param_type is float: + return {"type": "number"} + elif param_type is str: + return {"type": "string"} + elif param_type is bool: + return {"type": "boolean"} + + # Check for bare list type (without item type annotation) + if param_type is list: + raise TypeError(f"Parameter '{param_name}' has type 'list' without item type. Use list[int], list[str], etc.") + + # Handle list types with item type annotation + origin = get_origin(param_type) + if origin is list: + args = get_args(param_type) + if not args: + raise TypeError( + f"Parameter '{param_name}' has type 'list' without item type. Use list[int], list[str], etc." + ) + + item_type = args[0] + # Only support primitives in lists + if item_type is int: + return {"type": "array", "items": {"type": "integer"}} + elif item_type is float: + return {"type": "array", "items": {"type": "number"}} + elif item_type is str: + return {"type": "array", "items": {"type": "string"}} + elif item_type is bool: + return {"type": "array", "items": {"type": "boolean"}} + else: + raise TypeError( + f"Parameter '{param_name}' has unsupported list item type '{item_type.__name__}'. " + f"Supported list types: list[int], list[float], list[str], list[bool]" + ) + + # Unsupported type + type_name = getattr(param_type, "__name__", str(param_type)) + raise TypeError( + f"Parameter '{param_name}' has unsupported type '{type_name}'. " + f"Supported types: int, float, str, bool, list[int], list[float], list[str], list[bool]" + ) + + +def function_to_json_schema(func: Callable[..., Any], description: str | None = None) -> tuple[str, dict[str, Any]]: + """Generate JSON Schema from a Python function's signature and docstring. 
+ + This function introspects a Python function to extract its signature, type hints, + and docstring, then generates a JSON Schema that describes the function's parameters. + + Args: + func (Callable[..., Any]): The Python function to introspect. + description (str | None): Optional description override. If not provided, extracted from docstring. + + Returns: + tuple[str, dict[str, Any]]: Tuple of (function_description, parameters_schema). + + Raises: + TypeError: If a parameter is missing a type annotation or has an unsupported type. + + Examples: + >>> def add(a: float, b: float) -> float: + ... '''Add two numbers. + ... + ... Args: + ... a: First number. + ... b: Second number. + ... ''' + ... return a + b + >>> desc, schema = function_to_json_schema(add) + >>> desc + 'Add two numbers.' + >>> schema['properties']['a'] + {'type': 'number', 'description': 'First number.'} + >>> schema['required'] + ['a', 'b'] + """ + # Get function signature + sig = inspect.signature(func) + + # Get type hints + try: + type_hints = get_type_hints(func) + except Exception as e: + raise TypeError(f"Failed to get type hints for function '{func.__name__}': {e}") from e + + # Parse docstring if description not provided + if description is None: + description, param_descriptions = parse_google_docstring(func.__doc__) + # Use function name as description if no docstring + if not description: + description = func.__name__ + else: + _, param_descriptions = parse_google_docstring(func.__doc__) + + # Build JSON Schema + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + # Skip *args, **kwargs, self, cls + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + # Check if parameter has type hint + if param_name not in type_hints: + raise TypeError(f"Parameter '{param_name}' is missing type annotation") + + param_type = type_hints[param_name] + + # Convert type to JSON Schema + schema_type = 
type_to_json_schema(param_type, param_name) + + # Add description if available + if param_name in param_descriptions: + schema_type["description"] = param_descriptions[param_name] + + properties[param_name] = schema_type + + # Check if required (no default or default is not None for Optional types) + if param.default is inspect.Parameter.empty: + required.append(param_name) + elif param.default is not None: + # Has a non-None default, so it's required but with a default + required.append(param_name) + # If default is None, it's optional (not in required list) + + # Build final parameters schema + parameters_schema = { + "type": "object", + "properties": properties, + } + if required: + parameters_schema["required"] = required + + return description, parameters_schema diff --git a/tests/tests_unit/test_api/test_agents_actions.py b/tests/tests_unit/test_api/test_agents_actions.py new file mode 100644 index 0000000000..10e34b8110 --- /dev/null +++ b/tests/tests_unit/test_api/test_agents_actions.py @@ -0,0 +1,494 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from cognite.client import CogniteClient +from cognite.client.data_classes.agents import Message +from cognite.client.data_classes.agents.chat import ( + ActionCall, + AgentChatResponse, + ClientToolAction, + ClientToolCall, + ClientToolResult, + TextContent, + UnknownActionCall, +) + + +@pytest.fixture +def action_call_response_body() -> dict: + """Response body when agent requests an action.""" + return { + "agentId": "my_agent", + "agentExternalId": "my_agent", + "response": { + "cursor": "cursor_12345", + "messages": [ + { + "content": {"text": "", "type": "text"}, + "actions": [ + { + "type": "clientTool", + "actionId": "call_abc123", + "clientTool": { + "name": "add", + "arguments": '{"a": 42, "b": 58}', + }, + } + ], + "role": "agent", + } + ], + "type": "result", + }, + } + + +@pytest.fixture +def final_response_body() -> dict: + """Final response after 
class TestClientToolCall:
    def test_dump_serializes_arguments_to_json(self) -> None:
        """Dumping a call must emit the arguments as a JSON-encoded string."""
        import json

        expected_arguments = {"x": 5, "y": 7}
        dumped = ClientToolCall(
            action_id="call_456",
            name="multiply",
            arguments=expected_arguments,
        ).dump(camel_case=True)

        assert dumped["type"] == "clientTool"
        assert dumped["actionId"] == "call_456"
        client_tool = dumped["clientTool"]
        assert client_tool["name"] == "multiply"
        # Arguments should be JSON string
        assert json.loads(client_tool["arguments"]) == {"x": 5, "y": 7}


class TestActionCallPolymorphism:
    def test_unknown_action_call_loaded_for_unknown_type(self) -> None:
        """An unregistered `type` must load as the forward-compatible fallback class."""
        loaded = ActionCall._load(
            {
                "type": "unknownActionType",
                "actionId": "call_999",
                "someField": "someValue",
            }
        )

        assert isinstance(loaded, UnknownActionCall)
        assert loaded.action_id == "call_999"
        assert loaded.type == "unknownActionType"


class TestClientToolResult:
    def test_init_with_message_content(self) -> None:
        """A MessageContent instance passed as content is stored as-is."""
        payload = TextContent(text="Result: 100")
        result = ClientToolResult(action_id="call_456", content=payload)

        assert result.action_id == "call_456"
        assert result.content is payload
        assert result.content.text == "Result: 100"

    def test_load(self) -> None:
        """Loading a raw result payload restores the id and the text content."""
        raw = {
            "type": "clientTool",
            "actionId": "call_abc",
            "role": "action",
            "content": {"text": "Done", "type": "text"},
            "data": [],
        }

        loaded = ClientToolResult._load_result(raw)

        assert isinstance(loaded, ClientToolResult)
        assert loaded.action_id == "call_abc"
        assert loaded.content.text == "Done"
cognite_client.agents._post = MagicMock(return_value=MagicMock(json=lambda: action_call_response_body)) + + add_action = ClientToolAction( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + + response = cognite_client.agents.chat( + agent_id="my_agent", + messages=Message("What is 42 plus 58?"), + actions=[add_action], + ) + + # Verify actions were sent in request + call_args = cognite_client.agents._post.call_args + assert "actions" in call_args[1]["json"] + assert len(call_args[1]["json"]["actions"]) == 1 + assert call_args[1]["json"]["actions"][0]["type"] == "clientTool" + assert call_args[1]["json"]["actions"][0]["clientTool"]["name"] == "add" + + # Verify response + assert isinstance(response, AgentChatResponse) + assert response.action_calls is not None + assert len(response.action_calls) == 1 + + def test_chat_with_action_result_message(self, cognite_client: CogniteClient, final_response_body: dict) -> None: + cognite_client.agents._post = MagicMock(return_value=MagicMock(json=lambda: final_response_body)) + + add_action = ClientToolAction( + name="add", + description="Add two numbers", + parameters={"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}, + ) + + result = ClientToolResult( + action_id="call_abc123", + content="The result is 100", + ) + + response = cognite_client.agents.chat( + agent_id="my_agent", + messages=result, + cursor="cursor_12345", + actions=[add_action], + ) + + # Verify action result was sent + call_args = cognite_client.agents._post.call_args + assert call_args[1]["json"]["cursor"] == "cursor_12345" + assert len(call_args[1]["json"]["messages"]) == 1 + msg = call_args[1]["json"]["messages"][0] + assert msg["role"] == "action" + assert msg["type"] == "clientTool" + assert msg["actionId"] == "call_abc123" + + # Verify response + assert isinstance(response, 
AgentChatResponse) + assert response.text == "The result is 100." + + +class TestFromFunction: + """Test ClientToolAction.from_function() method.""" + + def test_basic_function_with_primitives(self) -> None: + """Test basic function with primitive types.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: The first number. + b: The second number. + + Returns: + The sum. + """ + return a + b + + action = ClientToolAction.from_function(add) + + assert action.name == "add" + assert action.description == "Add two numbers." + assert action.parameters == { + "type": "object", + "properties": { + "a": {"type": "number", "description": "The first number."}, + "b": {"type": "number", "description": "The second number."}, + }, + "required": ["a", "b"], + } + + def test_all_primitive_types(self) -> None: + """Test all supported primitive types.""" + + def example( + name: str, + count: int, + value: float, + enabled: bool, + ) -> None: + """Example function. + + Args: + name: A string parameter. + count: An integer parameter. + value: A float parameter. + enabled: A boolean parameter. + """ + pass + + action = ClientToolAction.from_function(example) + + assert action.parameters["properties"]["name"] == { + "type": "string", + "description": "A string parameter.", + } + assert action.parameters["properties"]["count"] == { + "type": "integer", + "description": "An integer parameter.", + } + assert action.parameters["properties"]["value"] == { + "type": "number", + "description": "A float parameter.", + } + assert action.parameters["properties"]["enabled"] == { + "type": "boolean", + "description": "A boolean parameter.", + } + assert action.parameters["required"] == ["name", "count", "value", "enabled"] + + def test_list_types(self) -> None: + """Test list parameter types.""" + + def process_lists( + names: list[str], + scores: list[float], + counts: list[int], + flags: list[bool], + ) -> None: + """Process various lists. 
+ + Args: + names: List of names. + scores: List of scores. + counts: List of counts. + flags: List of flags. + """ + pass + + action = ClientToolAction.from_function(process_lists) + + assert action.parameters["properties"]["names"] == { + "type": "array", + "items": {"type": "string"}, + "description": "List of names.", + } + assert action.parameters["properties"]["scores"] == { + "type": "array", + "items": {"type": "number"}, + "description": "List of scores.", + } + assert action.parameters["properties"]["counts"] == { + "type": "array", + "items": {"type": "integer"}, + "description": "List of counts.", + } + assert action.parameters["properties"]["flags"] == { + "type": "array", + "items": {"type": "boolean"}, + "description": "List of flags.", + } + + def test_optional_parameters(self) -> None: + """Test optional parameters (default=None).""" + + def greet(name: str, title: str | None = None) -> str: + """Greet a person. + + Args: + name: The person's name. + title: Optional title. + """ + return f"Hello, {title} {name}!" if title else f"Hello, {name}!" 
+ + action = ClientToolAction.from_function(greet) + + assert "name" in action.parameters["required"] + assert "title" not in action.parameters["required"] + assert "title" in action.parameters["properties"] + + def test_custom_name(self) -> None: + """Test custom action name override.""" + + def add(a: float, b: float) -> float: + """Add two numbers.""" + return a + b + + action = ClientToolAction.from_function(add, name="custom_add") + + assert action.name == "custom_add" + + def test_no_docstring_uses_function_name(self) -> None: + """Test that missing docstring uses function name with warning.""" + + def add(a: float, b: float) -> float: + return a + b + + with pytest.warns(UserWarning, match="has no docstring"): + action = ClientToolAction.from_function(add) + + assert action.name == "add" + assert action.description == "add" + + def test_multiline_param_description(self) -> None: + """Test parameter descriptions that span multiple lines.""" + + def example(param: str) -> None: + """Example function. + + Args: + param: This is a long description + that spans multiple lines + in the docstring. + """ + pass + + action = ClientToolAction.from_function(example) + + assert ( + action.parameters["properties"]["param"]["description"] + == "This is a long description that spans multiple lines in the docstring." + ) + + def test_multiline_function_description(self) -> None: + """Test function description that spans multiple lines.""" + + def example(param: str) -> None: + """This is a function description + that spans multiple lines + before the Args section. + + Args: + param: A parameter. + """ + pass + + action = ClientToolAction.from_function(example) + + assert action.description == "This is a function description that spans multiple lines before the Args section." 
+ + def test_missing_type_annotation_raises_error(self) -> None: + """Test that missing type annotation raises TypeError.""" + + def bad(param) -> None: + """Bad function.""" + pass + + with pytest.raises(TypeError, match="missing type annotation"): + ClientToolAction.from_function(bad) + + def test_unsupported_type_raises_error(self) -> None: + """Test that unsupported types raise TypeError.""" + + def bad(data: dict) -> None: + """Bad function. + + Args: + data: A dict parameter. + """ + pass + + with pytest.raises(TypeError, match="unsupported type"): + ClientToolAction.from_function(bad) + + def test_unsupported_list_item_type_raises_error(self) -> None: + """Test that unsupported list item types raise TypeError.""" + + def bad(items: list[dict]) -> None: + """Bad function. + + Args: + items: A list of dicts. + """ + pass + + with pytest.raises(TypeError, match="unsupported list item type"): + ClientToolAction.from_function(bad) + + def test_list_without_item_type_raises_error(self) -> None: + """Test that bare list type raises TypeError.""" + + def bad(items: list) -> None: + """Bad function. + + Args: + items: A list. + """ + pass + + with pytest.raises(TypeError, match="without item type"): + ClientToolAction.from_function(bad) + + def test_function_with_no_parameters(self) -> None: + """Test function with no parameters.""" + + def get_status() -> str: + """Get current status. + + Returns: + The status string. + """ + return "OK" + + action = ClientToolAction.from_function(get_status) + + assert action.name == "get_status" + assert action.description == "Get current status." 
+ assert action.parameters == { + "type": "object", + "properties": {}, + } + assert "required" not in action.parameters + + def test_param_without_description_in_docstring(self) -> None: + """Test parameter without description in docstring.""" + + def example(param: str) -> None: + """Example function.""" + pass + + action = ClientToolAction.from_function(example) + + # Parameter should exist but without description + assert "param" in action.parameters["properties"] + assert "description" not in action.parameters["properties"]["param"] + + def test_dump_and_use_in_chat(self) -> None: + """Test that generated action can be dumped and used in chat.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: First number. + b: Second number. + """ + return a + b + + action = ClientToolAction.from_function(add) + dumped = action.dump() + + assert dumped["type"] == "clientTool" + assert dumped["clientTool"]["name"] == "add" + assert dumped["clientTool"]["description"] == "Add two numbers." + assert dumped["clientTool"]["parameters"]["type"] == "object" diff --git a/tests/tests_unit/test_utils/test_function_introspection.py b/tests/tests_unit/test_utils/test_function_introspection.py new file mode 100644 index 0000000000..0cad30bbd5 --- /dev/null +++ b/tests/tests_unit/test_utils/test_function_introspection.py @@ -0,0 +1,288 @@ +"""Unit tests for function introspection utilities.""" + +from __future__ import annotations + +import pytest + +from cognite.client.utils._function_introspection import ( + function_to_json_schema, + parse_google_docstring, + type_to_json_schema, +) + + +class TestParseGoogleDocstring: + """Tests for parse_google_docstring function.""" + + def test_basic_docstring(self) -> None: + """Test parsing a basic docstring.""" + docstring = """Short description. + + Args: + param1: First parameter. + param2: Second parameter. + + Returns: + Some value. 
+ """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Short description." + assert params["param1"] == "First parameter." + assert params["param2"] == "Second parameter." + + def test_multiline_description(self) -> None: + """Test multiline function description.""" + docstring = """This is a long description + that spans multiple lines + before the Args section. + + Args: + param: A parameter. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "This is a long description that spans multiple lines before the Args section." + assert params["param"] == "A parameter." + + def test_multiline_param_description(self) -> None: + """Test multiline parameter descriptions.""" + docstring = """Function description. + + Args: + param: This is a long description + that spans multiple lines + in the docstring. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Function description." + assert params["param"] == "This is a long description that spans multiple lines in the docstring." + + def test_param_with_type_annotation(self) -> None: + """Test parameter with type annotation in docstring.""" + docstring = """Function description. + + Args: + param (str): String parameter. + count (int): Integer parameter. + """ + desc, params = parse_google_docstring(docstring) + + assert params["param"] == "String parameter." + assert params["count"] == "Integer parameter." + + def test_no_docstring(self) -> None: + """Test with None docstring.""" + desc, params = parse_google_docstring(None) + + assert desc == "" + assert params == {} + + def test_no_args_section(self) -> None: + """Test docstring without Args section.""" + docstring = """Just a description. + + Returns: + Some value. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Just a description." 
+ assert params == {} + + +class TestTypeToJsonSchema: + """Tests for type_to_json_schema function.""" + + def test_primitive_int(self) -> None: + """Test int type.""" + schema = type_to_json_schema(int, "param") + assert schema == {"type": "integer"} + + def test_primitive_float(self) -> None: + """Test float type.""" + schema = type_to_json_schema(float, "param") + assert schema == {"type": "number"} + + def test_primitive_str(self) -> None: + """Test str type.""" + schema = type_to_json_schema(str, "param") + assert schema == {"type": "string"} + + def test_primitive_bool(self) -> None: + """Test bool type.""" + schema = type_to_json_schema(bool, "param") + assert schema == {"type": "boolean"} + + def test_list_int(self) -> None: + """Test list[int] type.""" + schema = type_to_json_schema(list[int], "param") + assert schema == {"type": "array", "items": {"type": "integer"}} + + def test_list_float(self) -> None: + """Test list[float] type.""" + schema = type_to_json_schema(list[float], "param") + assert schema == {"type": "array", "items": {"type": "number"}} + + def test_list_str(self) -> None: + """Test list[str] type.""" + schema = type_to_json_schema(list[str], "param") + assert schema == {"type": "array", "items": {"type": "string"}} + + def test_list_bool(self) -> None: + """Test list[bool] type.""" + schema = type_to_json_schema(list[bool], "param") + assert schema == {"type": "array", "items": {"type": "boolean"}} + + def test_bare_list_raises_error(self) -> None: + """Test that bare list type raises TypeError.""" + with pytest.raises(TypeError, match="without item type"): + type_to_json_schema(list, "param") + + def test_unsupported_type_raises_error(self) -> None: + """Test that unsupported types raise TypeError.""" + with pytest.raises(TypeError, match="unsupported type"): + type_to_json_schema(dict, "param") + + def test_unsupported_list_item_type_raises_error(self) -> None: + """Test that unsupported list item types raise TypeError.""" + with 
pytest.raises(TypeError, match="unsupported list item type"): + type_to_json_schema(list[dict], "param") + + +class TestFunctionToJsonSchema: + """Tests for function_to_json_schema function.""" + + def test_basic_function(self) -> None: + """Test basic function with primitives.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: First number. + b: Second number. + """ + return a + b + + desc, schema = function_to_json_schema(add) + + assert desc == "Add two numbers." + assert schema == { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number."}, + "b": {"type": "number", "description": "Second number."}, + }, + "required": ["a", "b"], + } + + def test_all_primitive_types(self) -> None: + """Test all supported primitive types.""" + + def example(name: str, count: int, value: float, enabled: bool) -> None: + """Example function. + + Args: + name: A string. + count: An integer. + value: A float. + enabled: A boolean. + """ + pass + + desc, schema = function_to_json_schema(example) + + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["count"]["type"] == "integer" + assert schema["properties"]["value"]["type"] == "number" + assert schema["properties"]["enabled"]["type"] == "boolean" + assert schema["required"] == ["name", "count", "value", "enabled"] + + def test_optional_parameters(self) -> None: + """Test optional parameters (default=None).""" + + def greet(name: str, title: str | None = None) -> str: + """Greet a person. + + Args: + name: The person's name. + title: Optional title. + """ + return f"Hello, {title} {name}!" if title else f"Hello, {name}!" 
+ + desc, schema = function_to_json_schema(greet) + + assert "name" in schema["required"] + assert "title" not in schema["required"] + assert "title" in schema["properties"] + + def test_function_with_no_parameters(self) -> None: + """Test function with no parameters.""" + + def get_status() -> str: + """Get current status. + + Returns: + The status. + """ + return "OK" + + desc, schema = function_to_json_schema(get_status) + + assert desc == "Get current status." + assert schema == {"type": "object", "properties": {}} + assert "required" not in schema + + def test_function_with_no_docstring(self) -> None: + """Test function with no docstring.""" + + def add(a: float, b: float) -> float: + return a + b + + desc, schema = function_to_json_schema(add) + + assert desc == "add" + assert "a" in schema["properties"] + assert "b" in schema["properties"] + + def test_custom_description_override(self) -> None: + """Test custom description override.""" + + def add(a: float, b: float) -> float: + """Add two numbers.""" + return a + b + + desc, schema = function_to_json_schema(add, description="Custom description") + + assert desc == "Custom description" + + def test_missing_type_annotation_raises_error(self) -> None: + """Test that missing type annotation raises TypeError.""" + + def bad(param) -> None: + """Bad function.""" + pass + + with pytest.raises(TypeError, match="missing type annotation"): + function_to_json_schema(bad) + + def test_list_types(self) -> None: + """Test list parameter types.""" + + def process_lists(names: list[str], scores: list[float]) -> None: + """Process lists. + + Args: + names: List of names. + scores: List of scores. + """ + pass + + desc, schema = function_to_json_schema(process_lists) + + assert schema["properties"]["names"]["type"] == "array" + assert schema["properties"]["names"]["items"]["type"] == "string" + assert schema["properties"]["scores"]["type"] == "array" + assert schema["properties"]["scores"]["items"]["type"] == "number"