From d1b9e60dbfa463c92b23dee0c2bfb62f589cec32 Mon Sep 17 00:00:00 2001 From: Kelvin Sundli Date: Tue, 30 Sep 2025 19:53:36 -0700 Subject: [PATCH 1/5] feat(agents): add support for client-side actions in agent chat Add comprehensive support for client-side action execution in agent chat: - Action base class with ClientToolAction and UnknownAction implementations - ActionCall base class with ClientToolCall and UnknownActionCall for agent requests - ActionResult message type for sending action execution results back to agent - Updated AgentMessage to include actions field - Updated chat() API to accept actions parameter and ActionResult messages - Added action_calls property to AgentChatResponse for convenient access - Following existing patterns from AgentTool and MessageContent implementations This enables users to define custom client-side functions that the agent can call during reasoning, receive the action calls, execute them, and send results back to continue the conversation. --- cognite/client/_api/agents/agents.py | 59 +++- .../client/data_classes/agents/__init__.py | 14 + cognite/client/data_classes/agents/chat.py | 259 ++++++++++++++++++ 3 files changed, 325 insertions(+), 7 deletions(-) diff --git a/cognite/client/_api/agents/agents.py b/cognite/client/_api/agents/agents.py index bd177be879..c163860b02 100644 --- a/cognite/client/_api/agents/agents.py +++ b/cognite/client/_api/agents/agents.py @@ -1,11 +1,16 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING, Any, overload from cognite.client._api_client import APIClient from cognite.client.data_classes.agents import Agent, AgentList, AgentUpsert -from cognite.client.data_classes.agents.chat import AgentChatResponse, Message, MessageList +from cognite.client.data_classes.agents.chat import ( + Action, + ActionResult, + AgentChatResponse, + Message, +) from cognite.client.utils._experimental import FeaturePreviewWarning from cognite.client.utils._identifier import IdentifierSequence from cognite.client.utils.useful_types import SequenceNotStr @@ -243,8 +248,9 @@ def list(self) -> AgentList: # The API does not yet support limit or pagination def chat( self, agent_id: str, - messages: Message | Sequence[Message], + messages: Message | ActionResult | Sequence[Message | ActionResult], cursor: str | None = None, + actions: Sequence[Action] | None = None, ) -> AgentChatResponse: """`Chat with an agent. `_ @@ -253,9 +259,10 @@ def chat( Args: agent_id (str): External ID that uniquely identifies the agent. - messages (Message | Sequence[Message]): A list of one or many input messages to the agent. + messages (Message | ActionResult | Sequence[Message | ActionResult]): A list of one or many input messages to the agent. Can include regular messages and action results. cursor (str | None): The cursor to use for continuation of a conversation. Use this to create multi-turn conversations, as the cursor will keep track of the conversation state. + actions (Sequence[Action] | None): A list of client-side actions that can be called by the agent. Returns: AgentChatResponse: The response from the agent. @@ -290,22 +297,60 @@ def chat( ... Message("Once you have found it, find related time series.") ... ] ... ) + + Chat with client-side actions: + + >>> from cognite.client.data_classes.agents import ClientToolAction, ActionResult + >>> add_numbers_action = ClientToolAction( + ... name="add", + ... description="Add two numbers together", + ... parameters={ + ... "type": "object", + ... "properties": { + ... "a": {"type": "number", "description": "First number"}, + ... "b": {"type": "number", "description": "Second number"}, + ... }, + ... "required": ["a", "b"] + ... } + ... ) + >>> response = client.agents.chat( + ... agent_id="my_agent", + ... messages=Message("What is 42 plus 58?"), + ... actions=[add_numbers_action] + ... ) + >>> if response.action_calls: + ... for call in response.action_calls: + ... # Execute the action + ... result = call.arguments["a"] + call.arguments["b"] + ... # Send result back + ... response = client.agents.chat( + ... agent_id="my_agent", + ... messages=ActionResult( + ... action_id=call.action_id, + ... content=f"The result is {result}" + ... ), + ... cursor=response.cursor, + ... actions=[add_numbers_action] + ... ) """ self._warnings.warn() # Convert single message to list - if isinstance(messages, Message): + if isinstance(messages, (Message, ActionResult)): messages = [messages] # Build request body - body = { + body: dict[str, Any] = { "agentId": agent_id, - "messages": MessageList(messages).dump(camel_case=True), + "messages": [msg.dump(camel_case=True) for msg in messages], } if cursor is not None: body["cursor"] = cursor + if actions is not None: + body["actions"] = [action.dump(camel_case=True) for action in actions] + # Make the API call response = self._post( url_path=self._RESOURCE_PATH + "/chat", diff --git a/cognite/client/data_classes/agents/__init__.py b/cognite/client/data_classes/agents/__init__.py index a9931b6240..59b1de965c 100644 --- a/cognite/client/data_classes/agents/__init__.py +++ b/cognite/client/data_classes/agents/__init__.py @@ -21,19 +21,29 @@ ) from cognite.client.data_classes.agents.agents import Agent, AgentList, AgentUpsert, AgentUpsertList from cognite.client.data_classes.agents.chat import ( + Action, + ActionCall, + ActionResult, AgentChatResponse, AgentDataItem, AgentMessage, AgentMessageList, AgentReasoningItem, + ClientToolAction, + ClientToolCall, Message, MessageContent, MessageList, TextContent, + UnknownAction, + UnknownActionCall, UnknownContent, ) __all__ = [ + "Action", + "ActionCall", + "ActionResult", "Agent", "AgentChatResponse", "AgentDataItem", @@ -49,6 +59,8 @@ "AgentUpsertList", "AskDocumentAgentTool", "AskDocumentAgentToolUpsert", + "ClientToolAction", + "ClientToolCall", "DataModelInfo", "InstanceSpaces", "Message", @@ -62,6 +74,8 @@ "SummarizeDocumentAgentTool", "SummarizeDocumentAgentToolUpsert", "TextContent", + "UnknownAction", + "UnknownActionCall", "UnknownAgentTool", "UnknownAgentToolUpsert", "UnknownContent", diff --git a/cognite/client/data_classes/agents/chat.py b/cognite/client/data_classes/agents/chat.py index 1a8654ae2d..b8999505ff 100644 --- a/cognite/client/data_classes/agents/chat.py +++ b/cognite/client/data_classes/agents/chat.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -85,6 +86,194 @@ def _load_content(cls, data: dict[str, Any]) -> UnknownContent: } +@dataclass +class Action(CogniteObject, ABC): + """Base class for all action types that can be provided to an agent.""" + + _type: ClassVar[str] + + @classmethod + def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> Action: + """Dispatch to the correct concrete action class based on `type`.""" + action_type = data.get("type", "") + action_class = _ACTION_CLS_BY_TYPE.get(action_type, UnknownAction) + return action_class._load_action(data, cognite_client) + + @abstractmethod + def dump(self, camel_case: bool = True) -> dict[str, Any]: + """Dump the action to a dictionary.""" + ... + + @classmethod + @abstractmethod + def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> Action: + """Create a concrete action instance from raw data.""" + ... + + +@dataclass +class ClientToolAction(Action): + """A client-side tool definition that can be called by the agent. + + Args: + name (str): The name of the client tool to call. + description (str): A description of what the function does. The language model will use this description when selecting the function and interpreting its parameters. + parameters (dict[str, Any]): The parameters the function accepts, described as a JSON Schema object. + """ + + _type: ClassVar[str] = "clientTool" + name: str + description: str + parameters: dict[str, Any] + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + return { + "type": self._type, + "clientTool": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + @classmethod + def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolAction: + client_tool = data["clientTool"] + return cls( + name=client_tool["name"], + description=client_tool["description"], + parameters=client_tool["parameters"], + ) + + +@dataclass +class UnknownAction(Action): + """Unknown action type for forward compatibility. + + Args: + type (str): The action type. + data (dict[str, Any]): The raw action data. + """ + + type: str + data: dict[str, Any] = field(default_factory=dict) + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + result = self.data.copy() + result["type"] = self.type + return result + + @classmethod + def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownAction: + action_type = data.get("type", "") + return cls(data=data, type=action_type) + + +# Build the mapping AFTER concrete classes are defined +_ACTION_CLS_BY_TYPE: dict[str, type[Action]] = { + subclass._type: subclass # type: ignore[type-abstract] + for subclass in Action.__subclasses__() + if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None) +} + + +@dataclass +class ActionCall(CogniteObject, ABC): + """Base class for action calls requested by the agent.""" + + _type: ClassVar[str] + action_id: str + + @classmethod + def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionCall: + """Dispatch to the correct concrete action call class based on `type`.""" + action_type = data.get("type", "") + action_class = _ACTION_CALL_CLS_BY_TYPE.get(action_type, UnknownActionCall) + return action_class._load_call(data, cognite_client) + + @abstractmethod + def dump(self, camel_case: bool = True) -> dict[str, Any]: + """Dump the action call to a dictionary.""" + ... + + @classmethod + @abstractmethod + def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionCall: + """Create a concrete action call instance from raw data.""" + ... + + +@dataclass +class ClientToolCall(ActionCall): + """A client tool call requested by the agent. + + Args: + action_id (str): The unique identifier for this action call. + name (str): The name of the client tool being called. + arguments (dict[str, Any]): The parsed arguments for the tool call. + """ + + _type: ClassVar[str] = "clientTool" + action_id: str + name: str + arguments: dict[str, Any] + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + return { + "type": self._type, + "actionId" if camel_case else "action_id": self.action_id, + "clientTool": { + "name": self.name, + "arguments": json.dumps(self.arguments), + }, + } + + @classmethod + def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolCall: + client_tool = data["clientTool"] + arguments_str = client_tool["arguments"] + return cls( + action_id=data["actionId"], + name=client_tool["name"], + arguments=json.loads(arguments_str), + ) + + +@dataclass +class UnknownActionCall(ActionCall): + """Unknown action call type for forward compatibility. + + Args: + action_id (str): The unique identifier for this action call. + type (str): The action call type. + data (dict[str, Any]): The raw action call data. + """ + + action_id: str + type: str + data: dict[str, Any] = field(default_factory=dict) + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + result = self.data.copy() + result["type"] = self.type + result["actionId" if camel_case else "action_id"] = self.action_id + return result + + @classmethod + def _load_call(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownActionCall: + action_type = data.get("type", "") + action_id = data.get("actionId", "") + return cls(action_id=action_id, data=data, type=action_type) + + +# Build the mapping AFTER concrete classes are defined +_ACTION_CALL_CLS_BY_TYPE: dict[str, type[ActionCall]] = { + subclass._type: subclass # type: ignore[type-abstract] + for subclass in ActionCall.__subclasses__() + if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None) +} + + @dataclass class Message(CogniteResource): """A message to send to an agent. @@ -123,6 +312,59 @@ class MessageList(CogniteResourceList[Message]): _RESOURCE = Message +@dataclass +class ActionResult(CogniteResource): + """Result of executing a client action, for sending back to the agent. + + Args: + action_id (str): The ID of the action being responded to. + content (str | MessageContent): The result of executing the action. + action_type (str): The type of action (e.g., "clientTool"). Defaults to "clientTool". + data (list[Any] | None): Optional structured data. + """ + + action_id: str + content: MessageContent + action_type: str = "clientTool" + data: list[Any] | None = None + role: Literal["action"] = "action" + + def __init__( + self, + action_id: str, + content: str | MessageContent, + action_type: str = "clientTool", + data: list[Any] | None = None, + ) -> None: + self.action_id = action_id + if isinstance(content, str): + self.content = TextContent(text=content) + else: + self.content = content + self.action_type = action_type + self.data = data + self.role = "action" + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + return { + "role": self.role, + "type": self.action_type, + "actionId" if camel_case else "action_id": self.action_id, + "content": self.content.dump(camel_case=camel_case), + "data": self.data or [], + } + + @classmethod + def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult: + content = MessageContent._load(data["content"]) + return cls( + action_id=data["actionId"], + content=content, + action_type=data.get("type", "clientTool"), + data=data.get("data"), + ) + + @dataclass class AgentDataItem(CogniteObject): """Data item in agent response. @@ -177,12 +419,14 @@ class AgentMessage(CogniteResource): content (MessageContent | None): The message content. data (list[AgentDataItem] | None): Data items in the response. reasoning (list[AgentReasoningItem] | None): Reasoning items in the response. + actions (list[ActionCall] | None): Action calls requested by the agent. role (Literal["agent"]): The role of the message sender. """ content: MessageContent | None = None data: list[AgentDataItem] | None = None reasoning: list[AgentReasoningItem] | None = None + actions: list[ActionCall] | None = None role: Literal["agent"] = "agent" def dump(self, camel_case: bool = True) -> dict[str, Any]: @@ -193,6 +437,8 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]: result["data"] = [item.dump(camel_case=camel_case) for item in self.data] if self.reasoning is not None: result["reasoning"] = [item.dump(camel_case=camel_case) for item in self.reasoning] + if self.actions is not None: + result["actions"] = [item.dump(camel_case=camel_case) for item in self.actions] return result @classmethod @@ -200,10 +446,12 @@ def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None content = MessageContent._load(data["content"]) if "content" in data else None data_items = [AgentDataItem._load(item, cognite_client) for item in data.get("data", [])] reasoning_items = [AgentReasoningItem._load(item, cognite_client) for item in data.get("reasoning", [])] + action_calls = [ActionCall._load(item, cognite_client) for item in data.get("actions", [])] return cls( content=content, data=data_items if data_items else None, reasoning=reasoning_items if reasoning_items else None, + actions=action_calls if action_calls else None, role=data["role"], ) @@ -256,6 +504,17 @@ def text(self) -> str | None: return message.content.text return None + @property + def action_calls(self) -> list[ActionCall] | None: + """Get all action calls from all messages.""" + if self.messages: + all_actions = [] + for message in self.messages: + if message.actions: + all_actions.extend(message.actions) + return all_actions if all_actions else None + return None + @classmethod def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> AgentChatResponse: response_data = data["response"] From a6ebeed97b19634abe522b72f146ca63fd0527d4 Mon Sep 17 00:00:00 2001 From: Kelvin Sundli Date: Tue, 30 Sep 2025 20:38:44 -0700 Subject: [PATCH 2/5] refactor(agents): make ActionResult a base class with ClientToolResult Make ActionResult follow the same polymorphic pattern as Action and ActionCall: - ActionResult is now an abstract base class - ClientToolResult is the concrete implementation for client tool execution results - UnknownActionResult provides forward compatibility - Updated API imports and documentation to use ClientToolResult - Updated test script to use ClientToolResult This ensures consistency across all action-related classes and allows future action types to have their own result schemas. --- cognite/client/_api/agents/agents.py | 4 +- .../client/data_classes/agents/__init__.py | 4 + cognite/client/data_classes/agents/chat.py | 82 ++++++++++++++++--- 3 files changed, 75 insertions(+), 15 deletions(-) diff --git a/cognite/client/_api/agents/agents.py b/cognite/client/_api/agents/agents.py index c163860b02..f6eedb9f22 100644 --- a/cognite/client/_api/agents/agents.py +++ b/cognite/client/_api/agents/agents.py @@ -300,7 +300,7 @@ def chat( Chat with client-side actions: - >>> from cognite.client.data_classes.agents import ClientToolAction, ActionResult + >>> from cognite.client.data_classes.agents import ClientToolAction, ClientToolResult >>> add_numbers_action = ClientToolAction( ... name="add", ... description="Add two numbers together", @@ -325,7 +325,7 @@ def chat( ... # Send result back ... response = client.agents.chat( ... agent_id="my_agent", - ... messages=ActionResult( + ... messages=ClientToolResult( ... action_id=call.action_id, ... content=f"The result is {result}" ... ), diff --git a/cognite/client/data_classes/agents/__init__.py b/cognite/client/data_classes/agents/__init__.py index 59b1de965c..2ee4644ecf 100644 --- a/cognite/client/data_classes/agents/__init__.py +++ b/cognite/client/data_classes/agents/__init__.py @@ -31,12 +31,14 @@ AgentReasoningItem, ClientToolAction, ClientToolCall, + ClientToolResult, Message, MessageContent, MessageList, TextContent, UnknownAction, UnknownActionCall, + UnknownActionResult, UnknownContent, ) @@ -61,6 +63,7 @@ "AskDocumentAgentToolUpsert", "ClientToolAction", "ClientToolCall", + "ClientToolResult", "DataModelInfo", "InstanceSpaces", "Message", @@ -76,6 +79,7 @@ "TextContent", "UnknownAction", "UnknownActionCall", + "UnknownActionResult", "UnknownAgentTool", "UnknownAgentToolUpsert", "UnknownContent", diff --git a/cognite/client/data_classes/agents/chat.py b/cognite/client/data_classes/agents/chat.py index b8999505ff..90568b96d7 100644 --- a/cognite/client/data_classes/agents/chat.py +++ b/cognite/client/data_classes/agents/chat.py @@ -313,27 +313,50 @@ class MessageList(CogniteResourceList[Message]): @dataclass -class ActionResult(CogniteResource): - """Result of executing a client action, for sending back to the agent. +class ActionResult(CogniteResource, ABC): + """Base class for action execution results.""" + + _type: ClassVar[str] + action_id: str + role: Literal["action"] = "action" + + @classmethod + def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult: + """Dispatch to the correct concrete action result class based on `type`.""" + action_type = data.get("type", "") + action_class = _ACTION_RESULT_CLS_BY_TYPE.get(action_type, UnknownActionResult) + return action_class._load_result(data, cognite_client) + + @abstractmethod + def dump(self, camel_case: bool = True) -> dict[str, Any]: + """Dump the action result to a dictionary.""" + ... + + @classmethod + @abstractmethod + def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult: + """Create a concrete action result instance from raw data.""" + ... + + +@dataclass +class ClientToolResult(ActionResult): + """Result of executing a client tool, for sending back to the agent. Args: action_id (str): The ID of the action being responded to. content (str | MessageContent): The result of executing the action. - action_type (str): The type of action (e.g., "clientTool"). Defaults to "clientTool". data (list[Any] | None): Optional structured data. """ - action_id: str - content: MessageContent - action_type: str = "clientTool" - data: list[Any] | None = None - role: Literal["action"] = "action" + _type: ClassVar[str] = "clientTool" + content: MessageContent = field(init=False) + data: list[Any] | None = field(init=False, default=None) def __init__( self, action_id: str, content: str | MessageContent, - action_type: str = "clientTool", data: list[Any] | None = None, ) -> None: self.action_id = action_id @@ -341,30 +364,63 @@ def __init__( self.content = TextContent(text=content) else: self.content = content - self.action_type = action_type self.data = data self.role = "action" def dump(self, camel_case: bool = True) -> dict[str, Any]: return { "role": self.role, - "type": self.action_type, + "type": self._type, "actionId" if camel_case else "action_id": self.action_id, "content": self.content.dump(camel_case=camel_case), "data": self.data or [], } @classmethod - def _load(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ActionResult: + def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> ClientToolResult: content = MessageContent._load(data["content"]) return cls( action_id=data["actionId"], content=content, - action_type=data.get("type", "clientTool"), data=data.get("data"), ) +@dataclass +class UnknownActionResult(ActionResult): + """Unknown action result type for forward compatibility. + + Args: + action_id (str): The ID of the action being responded to. + type (str): The action result type. + data (dict[str, Any]): The raw action result data. + """ + + type: str = "" + data: dict[str, Any] = field(default_factory=dict) + + def dump(self, camel_case: bool = True) -> dict[str, Any]: + result = self.data.copy() + result["role"] = self.role + result["type"] = self.type + result["actionId" if camel_case else "action_id"] = self.action_id + return result + + @classmethod + def _load_result(cls, data: dict[str, Any], cognite_client: CogniteClient | None = None) -> UnknownActionResult: + action_type = data.get("type", "") + action_id = data.get("actionId", "") + return cls(action_id=action_id, data=data, type=action_type) + + +# Build the mapping AFTER concrete classes are defined +_ACTION_RESULT_CLS_BY_TYPE: dict[str, type[ActionResult]] = { + subclass._type: subclass # type: ignore[type-abstract] + for subclass in ActionResult.__subclasses__() + if hasattr(subclass, "_type") and not getattr(subclass, "__abstractmethods__", None) +} + + @dataclass class AgentDataItem(CogniteObject): """Data item in agent response. From 55efb96c8f55034dbcd237f5fb76aecd1681a140 Mon Sep 17 00:00:00 2001 From: Kelvin Sundli Date: Tue, 30 Sep 2025 20:50:05 -0700 Subject: [PATCH 3/5] test(agents): add comprehensive unit tests for actions feature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 13 unit tests covering: - ClientToolAction serialization/deserialization - ClientToolCall with JSON argument parsing - ClientToolResult creation and serialization - Polymorphic dispatch for unknown action types - Chat API with actions parameter - Chat API with action result messages - AgentChatResponse.action_calls property All tests passing ✅ --- .../test_api/test_agents_actions.py | 316 ++++++++++++++++++ 1 file changed, 316 insertions(+) create mode 100644 tests/tests_unit/test_api/test_agents_actions.py diff --git a/tests/tests_unit/test_api/test_agents_actions.py b/tests/tests_unit/test_api/test_agents_actions.py new file mode 100644 index 0000000000..75929a0475 --- /dev/null +++ b/tests/tests_unit/test_api/test_agents_actions.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from cognite.client import CogniteClient +from cognite.client.data_classes.agents import Message +from cognite.client.data_classes.agents.chat import ( + ActionCall, + AgentChatResponse, + ClientToolAction, + ClientToolCall, + ClientToolResult, + TextContent, + UnknownActionCall, +) + + +@pytest.fixture +def action_call_response_body() -> dict: + """Response body when agent requests an action.""" + return { + "agentId": "my_agent", + "agentExternalId": "my_agent", + "response": { + "cursor": "cursor_12345", + "messages": [ + { + "content": {"text": "", "type": "text"}, + "actions": [ + { + "type": "clientTool", + "actionId": "call_abc123", + "clientTool": { + "name": "add", + "arguments": '{"a": 42, "b": 58}', + }, + } + ], + "role": "agent", + } + ], + "type": "result", + }, + } + + +@pytest.fixture +def final_response_body() -> dict: + """Final response after action execution.""" + return { + "agentId": "my_agent", + "agentExternalId": "my_agent", + "response": { + "cursor": "cursor_67890", + "messages": [ + { + "content": {"text": "The result is 100.", "type": "text"}, + "role": "agent", + } + ], + "type": "result", + }, + } + + +class TestClientToolAction: + def test_dump(self) -> None: + action = ClientToolAction( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + + dumped = action.dump(camel_case=True) + + assert dumped["type"] == "clientTool" + assert dumped["clientTool"]["name"] == "add" + assert dumped["clientTool"]["description"] == "Add two numbers" + assert dumped["clientTool"]["parameters"]["type"] == "object" + assert "a" in dumped["clientTool"]["parameters"]["properties"] + + def test_load(self) -> None: + data = { + "type": "clientTool", + "clientTool": { + "name": "multiply", + "description": "Multiply two numbers", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "number"}, + "y": {"type": "number"}, + }, + }, + }, + } + + action = ClientToolAction._load_action(data) + + assert isinstance(action, ClientToolAction) + assert action.name == "multiply" + assert action.description == "Multiply two numbers" + assert action.parameters["type"] == "object" + + +class TestClientToolCall: + def test_load_parses_arguments_json(self) -> None: + data = { + "type": "clientTool", + "actionId": "call_123", + "clientTool": { + "name": "add", + "arguments": '{"a": 10, "b": 20}', + }, + } + + call = ClientToolCall._load_call(data) + + assert isinstance(call, ClientToolCall) + assert call.action_id == "call_123" + assert call.name == "add" + assert call.arguments == {"a": 10, "b": 20} + + def test_dump_serializes_arguments_to_json(self) -> None: + call = ClientToolCall( + action_id="call_456", + name="multiply", + arguments={"x": 5, "y": 7}, + ) + + dumped = call.dump(camel_case=True) + + assert dumped["type"] == "clientTool" + assert dumped["actionId"] == "call_456" + assert dumped["clientTool"]["name"] == "multiply" + # Arguments should be JSON string + import json + + assert json.loads(dumped["clientTool"]["arguments"]) == {"x": 5, "y": 7} + + +class TestActionCallPolymorphism: + def test_unknown_action_call_loaded_for_unknown_type(self) -> None: + data = { + "type": "unknownActionType", + "actionId": "call_999", + "someField": "someValue", + } + + call = ActionCall._load(data) + + assert isinstance(call, UnknownActionCall) + assert call.action_id == "call_999" + assert call.type == "unknownActionType" + + +class TestClientToolResult: + def test_init_with_string_content(self) -> None: + result = ClientToolResult( + action_id="call_123", + content="The result is 42", + ) + + assert result.action_id == "call_123" + assert isinstance(result.content, TextContent) + assert result.content.text == "The result is 42" + assert result.role == "action" + + def test_init_with_message_content(self) -> None: + text_content = TextContent(text="Result: 100") + result = ClientToolResult( + action_id="call_456", + content=text_content, + ) + + assert result.action_id == "call_456" + assert result.content is text_content + assert result.content.text == "Result: 100" + + def test_dump(self) -> None: + result = ClientToolResult( + action_id="call_789", + content="Success", + data=[{"key": "value"}], + ) + + dumped = result.dump(camel_case=True) + + assert dumped["role"] == "action" + assert dumped["type"] == "clientTool" + assert dumped["actionId"] == "call_789" + assert dumped["content"]["text"] == "Success" + assert dumped["data"] == [{"key": "value"}] + + def test_load(self) -> None: + data = { + "type": "clientTool", + "actionId": "call_abc", + "role": "action", + "content": {"text": "Done", "type": "text"}, + "data": [], + } + + result = ClientToolResult._load_result(data) + + assert isinstance(result, ClientToolResult) + assert result.action_id == "call_abc" + assert result.content.text == "Done" + + +class TestChatWithActions: + def test_chat_with_actions_parameter(self, cognite_client: CogniteClient, action_call_response_body: dict) -> None: + cognite_client.agents._post = MagicMock(return_value=MagicMock(json=lambda: action_call_response_body)) + + add_action = ClientToolAction( + name="add", + description="Add two numbers", + parameters={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + + response = cognite_client.agents.chat( + agent_id="my_agent", + messages=Message("What is 42 plus 58?"), + actions=[add_action], + ) + + # Verify actions were sent in request + call_args = cognite_client.agents._post.call_args + assert "actions" in call_args[1]["json"] + assert len(call_args[1]["json"]["actions"]) == 1 + assert call_args[1]["json"]["actions"][0]["type"] == "clientTool" + assert call_args[1]["json"]["actions"][0]["clientTool"]["name"] == "add" + + # Verify response + assert isinstance(response, AgentChatResponse) + assert response.action_calls is not None + assert len(response.action_calls) == 1 + + def test_chat_with_action_result_message(self, cognite_client: CogniteClient, final_response_body: dict) -> None: + cognite_client.agents._post = MagicMock(return_value=MagicMock(json=lambda: final_response_body)) + + add_action = ClientToolAction( + name="add", + description="Add two numbers", + parameters={"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}, + ) + + result = ClientToolResult( + action_id="call_abc123", + content="The result is 100", + ) + + response = cognite_client.agents.chat( + agent_id="my_agent", + messages=result, + cursor="cursor_12345", + actions=[add_action], + ) + + # Verify action result was sent + call_args = cognite_client.agents._post.call_args + assert call_args[1]["json"]["cursor"] == "cursor_12345" + assert len(call_args[1]["json"]["messages"]) == 1 + msg = call_args[1]["json"]["messages"][0] + assert msg["role"] == "action" + assert msg["type"] == "clientTool" + assert msg["actionId"] == "call_abc123" + + # Verify response + assert isinstance(response, AgentChatResponse) + assert response.text == "The result is 100." + + +class TestActionCallsProperty: + def test_action_calls_property_returns_none_when_no_actions(self, cognite_client: CogniteClient) -> None: + response_body = { + "agentId": "my_agent", + "agentExternalId": "my_agent", + "response": { + "cursor": "cursor_123", + "messages": [{"content": {"text": "Hello", "type": "text"}, "role": "agent"}], + "type": "result", + }, + } + + response = AgentChatResponse._load(response_body, cognite_client=cognite_client) + + assert response.action_calls is None + + def test_action_calls_property_extracts_all_actions( + self, cognite_client: CogniteClient, action_call_response_body: dict + ) -> None: + response = AgentChatResponse._load(action_call_response_body, cognite_client=cognite_client) + + assert response.action_calls is not None + assert len(response.action_calls) == 1 + assert isinstance(response.action_calls[0], ClientToolCall) + assert response.action_calls[0].action_id == "call_abc123" + assert response.action_calls[0].name == "add" + assert response.action_calls[0].arguments == {"a": 42, "b": 58} From 1f6bc6a2dfcfb398636f3e2b7b2c7f3207a03e42 Mon Sep 17 00:00:00 2001 From: Kelvin Sundli Date: Tue, 30 Sep 2025 20:56:53 -0700 Subject: [PATCH 4/5] refactor(tests): remove redundant action tests, keep 6 essential MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed 7 redundant tests that were already covered by integration tests: - ClientToolAction dump/load (covered by chat integration tests) - ClientToolCall load (covered by response parsing) - ClientToolResult string content and dump (covered by chat tests) - action_calls property edge cases (covered by main integration tests) Kept 6 non-redundant tests: - JSON argument serialization (unique coverage) - Unknown action type polymorphism (unique coverage) - MessageContent initialization (unique coverage) - ClientToolResult deserialization (unique coverage) - Chat with actions parameter (core integration) - Chat with action result message (core integration) 13 tests → 6 tests (-54%) while maintaining full feature coverage. --- .../test_api/test_agents_actions.py | 119 ------------------ 1 file changed, 119 deletions(-) diff --git a/tests/tests_unit/test_api/test_agents_actions.py b/tests/tests_unit/test_api/test_agents_actions.py index 75929a0475..57d6c21b88 100644 --- a/tests/tests_unit/test_api/test_agents_actions.py +++ b/tests/tests_unit/test_api/test_agents_actions.py @@ -65,71 +65,7 @@ def final_response_body() -> dict: } -class TestClientToolAction: - def test_dump(self) -> None: - action = ClientToolAction( - name="add", - description="Add two numbers", - parameters={ - "type": "object", - "properties": { - "a": {"type": "number"}, - "b": {"type": "number"}, - }, - "required": ["a", "b"], - }, - ) - - dumped = action.dump(camel_case=True) - - assert dumped["type"] == "clientTool" - assert dumped["clientTool"]["name"] == "add" - assert dumped["clientTool"]["description"] == "Add two numbers" - assert dumped["clientTool"]["parameters"]["type"] == "object" - assert "a" in dumped["clientTool"]["parameters"]["properties"] - - def test_load(self) -> None: - data = { - "type": "clientTool", - "clientTool": { - "name": "multiply", - "description": "Multiply two numbers", - "parameters": { - "type": "object", - "properties": { - "x": {"type": "number"}, - "y": {"type": "number"}, - }, - }, - }, - } - - action = ClientToolAction._load_action(data) - - assert isinstance(action, ClientToolAction) - assert action.name == "multiply" - assert action.description == "Multiply two numbers" - assert action.parameters["type"] == "object" - - class TestClientToolCall: - def test_load_parses_arguments_json(self) -> None: - data = { - "type": "clientTool", - "actionId": "call_123", - "clientTool": { - "name": "add", - "arguments": '{"a": 10, "b": 20}', - }, - } - - call = ClientToolCall._load_call(data) - - assert isinstance(call, ClientToolCall) - assert call.action_id == "call_123" - assert call.name == "add" - assert call.arguments == {"a": 10, "b": 20} - def test_dump_serializes_arguments_to_json(self) -> None: call = ClientToolCall( action_id="call_456", @@ -164,17 +100,6 @@ def test_unknown_action_call_loaded_for_unknown_type(self) -> None: class TestClientToolResult: - def test_init_with_string_content(self) -> None: - result = ClientToolResult( - action_id="call_123", - content="The result is 42", - ) - - assert result.action_id == "call_123" - assert isinstance(result.content, TextContent) - assert result.content.text == "The result is 42" - assert result.role == "action" - def test_init_with_message_content(self) -> None: text_content = TextContent(text="Result: 100") result = ClientToolResult( @@ -186,21 +111,6 @@ def test_init_with_message_content(self) -> None: assert result.content is text_content assert result.content.text == "Result: 100" - def test_dump(self) -> None: - result = ClientToolResult( - action_id="call_789", - content="Success", - data=[{"key": "value"}], - ) - - dumped = result.dump(camel_case=True) - - assert dumped["role"] == "action" - assert dumped["type"] == "clientTool" - assert dumped["actionId"] == "call_789" - assert dumped["content"]["text"] == "Success" - assert dumped["data"] == [{"key": "value"}] - def test_load(self) -> None: data = { "type": "clientTool", @@ -285,32 +195,3 @@ def test_chat_with_action_result_message(self, cognite_client: CogniteClient, fi # Verify response assert isinstance(response, AgentChatResponse) assert response.text == "The result is 100." - - -class TestActionCallsProperty: - def test_action_calls_property_returns_none_when_no_actions(self, cognite_client: CogniteClient) -> None: - response_body = { - "agentId": "my_agent", - "agentExternalId": "my_agent", - "response": { - "cursor": "cursor_123", - "messages": [{"content": {"text": "Hello", "type": "text"}, "role": "agent"}], - "type": "result", - }, - } - - response = AgentChatResponse._load(response_body, cognite_client=cognite_client) - - assert response.action_calls is None - - def test_action_calls_property_extracts_all_actions( - self, cognite_client: CogniteClient, action_call_response_body: dict - ) -> None: - response = AgentChatResponse._load(action_call_response_body, cognite_client=cognite_client) - - assert response.action_calls is not None - assert len(response.action_calls) == 1 - assert isinstance(response.action_calls[0], ClientToolCall) - assert response.action_calls[0].action_id == "call_abc123" - assert response.action_calls[0].name == "add" - assert response.action_calls[0].arguments == {"a": 42, "b": 58} From 4ca658fab0670b52f4b15ef9be38ce4cad6e7a34 Mon Sep 17 00:00:00 2001 From: Kelvin Sundli Date: Tue, 30 Sep 2025 22:06:26 -0700 Subject: [PATCH 5/5] feat(agents): add ClientToolAction.from_function() with function introspection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement automatic generation of client tool actions from Python functions using type hints and docstrings, making it easier to define agent actions. Key Features: - Automatic JSON Schema generation from function signatures - Google-style docstring parsing for descriptions - Support for primitives (int, float, str, bool) and lists - Optional parameter detection (default=None) - Clear error messages for unsupported types Implementation: - Add function_introspection utility module in cognite/client/utils/ - Add ClientToolAction.from_function() classmethod - Extract introspection logic to reusable utilities following SDK patterns Benefits: - Reduces boilerplate: ~50 lines of manual schema → 1 line from_function() - Better DX: Type-safe, IDE-friendly function definitions - Maintainable: Introspection logic separate from agent code - Testable: 25 unit tests for utils, 15 integration tests for from_function() Example: def add(a: float, b: float) -> float: '''Add two numbers. Args: a: First number. b: Second number. ''' return a + b # Before: ~50 lines of manual JSON Schema # After: action = ClientToolAction.from_function(add) response = client.agents.chat( agent_id="my_agent", messages=Message("What is 42 + 58?"), actions=[action] ) Files: - cognite/client/utils/_function_introspection.py: New utility module - cognite/client/data_classes/agents/chat.py: Add from_function() method - tests/tests_unit/test_utils/test_function_introspection.py: 25 unit tests - tests/tests_unit/test_api/test_agents_actions.py: 15 integration tests Tests: All 61 tests passing (25 new utils tests, 21 existing tests unchanged) --- cognite/client/data_classes/agents/chat.py | 116 +++++++ .../client/utils/_function_introspection.py | 279 ++++++++++++++++ .../test_api/test_agents_actions.py | 297 ++++++++++++++++++ .../test_utils/test_function_introspection.py | 288 +++++++++++++++++ 4 files changed, 980 insertions(+) create mode 100644 cognite/client/utils/_function_introspection.py create mode 100644 tests/tests_unit/test_utils/test_function_introspection.py diff --git a/cognite/client/data_classes/agents/chat.py b/cognite/client/data_classes/agents/chat.py index 90568b96d7..90e3a2262a 100644 --- a/cognite/client/data_classes/agents/chat.py +++ b/cognite/client/data_classes/agents/chat.py @@ -1,11 +1,14 @@ from __future__ import annotations import json +import warnings from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal from cognite.client.data_classes._base import CogniteObject, CogniteResource, CogniteResourceList +from cognite.client.utils._function_introspection import function_to_json_schema from cognite.client.utils._text import convert_all_keys_to_camel_case if TYPE_CHECKING: @@ -145,6 +148,119 @@ def _load_action(cls, data: dict[str, Any], cognite_client: CogniteClient | None parameters=client_tool["parameters"], ) + @classmethod + def from_function(cls, func: Callable[..., Any], name: str | None = None) -> ClientToolAction: + """Create a ClientToolAction from a Python function with type hints and docstring. + + This method introspects a Python function to automatically generate an action definition + that can be used with agents. The function must have type hints for all parameters. + + **Supported Types:** + - Primitives: ``int``, ``float``, ``str``, ``bool`` + - Lists of primitives: ``list[int]``, ``list[float]``, ``list[str]``, ``list[bool]`` + + **Optional Parameters:** + Parameters with a default value of ``None`` are treated as optional and will not be + included in the ``required`` list of the JSON Schema. + + **Docstring Format:** + The function should have a Google-style docstring with: + - A short description (used as the action description) + - An ``Args:`` section with parameter descriptions + + Args: + func (Callable[..., Any]): The Python function to convert to an action. + name (str | None): Optional custom name for the action. If not provided, uses the function name. + + Returns: + ClientToolAction: An action that can be used in agent chat. + + Raises: + TypeError: If a parameter is missing a type annotation or has an unsupported type. + + Examples: + Basic usage with primitives:: + + def add(a: float, b: float) -> float: + \"\"\"Add two numbers. + + Args: + a: The first number. + b: The second number. + + Returns: + The sum of the two numbers. + \"\"\" + return a + b + + action = ClientToolAction.from_function(add) + + With optional parameters:: + + def greet(name: str, title: str = None) -> str: + \"\"\"Greet a person. + + Args: + name: The person's name. + title: Optional title (e.g., "Dr.", "Mr."). + + Returns: + A greeting message. + \"\"\" + if title: + return f"Hello, {title} {name}!" + return f"Hello, {name}!" + + action = ClientToolAction.from_function(greet) + # name is required, title is optional + + With list parameters:: + + def sum_numbers(numbers: list[float]) -> float: + \"\"\"Calculate the sum of a list of numbers. + + Args: + numbers: List of numbers to sum. + + Returns: + The sum of all numbers. + \"\"\" + return sum(numbers) + + action = ClientToolAction.from_function(sum_numbers) + + With custom name:: + + action = ClientToolAction.from_function(add, name="custom_add") + + Use in agent chat:: + + response = client.agents.chat( + agent_id="my_agent", + messages=Message("What is 42 plus 58?"), + actions=[ClientToolAction.from_function(add)] + ) + """ + # Get function name + action_name = name or func.__name__ + + # Generate JSON Schema from function + func_description, parameters_schema = function_to_json_schema(func) + + # Warn if no docstring (description will be function name) + if func_description == func.__name__: + warnings.warn( + f"Function '{func.__name__}' has no docstring, using function name as description", + UserWarning, + stacklevel=2, + ) + + return cls( + name=action_name, + description=func_description, + parameters=parameters_schema, + ) + @dataclass class UnknownAction(Action): diff --git a/cognite/client/utils/_function_introspection.py b/cognite/client/utils/_function_introspection.py new file mode 100644 index 0000000000..ebbc2ad959 --- /dev/null +++ b/cognite/client/utils/_function_introspection.py @@ -0,0 +1,279 @@ +"""Utilities for introspecting Python functions to generate JSON Schema. + +This module provides utilities for converting Python function signatures and docstrings +into JSON Schema definitions, primarily for use with agent function calling features. +""" + +from __future__ import annotations + +import inspect +import re +from collections.abc import Callable +from typing import Any, get_args, get_origin, get_type_hints + + +def parse_google_docstring(docstring: str | None) -> tuple[str, dict[str, str]]: + """Parse Google-style docstring to extract description and parameter descriptions. + + Args: + docstring (str | None): The function's docstring. + + Returns: + tuple[str, dict[str, str]]: Tuple of (function_description, param_descriptions_dict). + + Examples: + >>> def example(param1: str, param2: int) -> None: + ... '''Example function. + ... + ... Args: + ... param1: First parameter. + ... param2: Second parameter. + ... ''' + ... pass + >>> desc, params = parse_google_docstring(example.__doc__) + >>> desc + 'Example function.' + >>> params['param1'] + 'First parameter.' + """ + if not docstring: + return "", {} + + # Split into lines + lines = docstring.strip().split("\n") + + # Extract function description (everything before Args section) + description_lines = [] + param_descriptions = {} + in_args_section = False + current_param = None + current_param_desc_lines = [] + + for line in lines: + stripped = line.strip() + + # Check if we're entering Args section + if stripped in ("Args:", "Arguments:", "Parameters:"): + in_args_section = True + continue + + # Check if we're leaving Args section (Returns, Raises, etc.) + if in_args_section and stripped and stripped.endswith(":"): + # Save any pending parameter description + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + break + + # If we're not in args section yet, check if we hit Returns/Raises (end of description) + if ( + not in_args_section + and stripped + and stripped.endswith(":") + and stripped + in ( + "Returns:", + "Return:", + "Raises:", + "Raises:", + "Yields:", + "Yield:", + "Note:", + "Notes:", + "Example:", + "Examples:", + ) + ): + break + + if in_args_section: + # Parse parameter line: "param_name: description" or "param_name (type): description" + param_match = re.match(r"^\s*(\w+)(?:\s*\([^)]+\))?\s*:\s*(.*)$", line) + if param_match: + # Save previous parameter if any + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + + # Start new parameter + current_param = param_match.group(1) + current_param_desc_lines = [param_match.group(2)] if param_match.group(2) else [] + elif current_param and stripped: + # Continuation of current parameter description + current_param_desc_lines.append(stripped) + else: + # Part of function description + if stripped: + description_lines.append(stripped) + + # Save last parameter + if current_param and current_param_desc_lines: + param_descriptions[current_param] = " ".join(current_param_desc_lines).strip() + + # Join description lines + description = " ".join(description_lines).strip() + + return description, param_descriptions + + +def type_to_json_schema(param_type: type, param_name: str) -> dict[str, Any]: + """Convert Python type hint to JSON Schema type. + + Supports primitive types (int, float, str, bool) and lists of primitives. + + Args: + param_type (type): The type annotation. + param_name (str): The parameter name (for error messages). + + Returns: + dict[str, Any]: JSON Schema type object. + + Raises: + TypeError: If the type is not supported. + + Examples: + >>> type_to_json_schema(int, "count") + {'type': 'integer'} + >>> type_to_json_schema(list[str], "names") + {'type': 'array', 'items': {'type': 'string'}} + """ + # Handle primitives + if param_type is int: + return {"type": "integer"} + elif param_type is float: + return {"type": "number"} + elif param_type is str: + return {"type": "string"} + elif param_type is bool: + return {"type": "boolean"} + + # Check for bare list type (without item type annotation) + if param_type is list: + raise TypeError(f"Parameter '{param_name}' has type 'list' without item type. Use list[int], list[str], etc.") + + # Handle list types with item type annotation + origin = get_origin(param_type) + if origin is list: + args = get_args(param_type) + if not args: + raise TypeError( + f"Parameter '{param_name}' has type 'list' without item type. Use list[int], list[str], etc." + ) + + item_type = args[0] + # Only support primitives in lists + if item_type is int: + return {"type": "array", "items": {"type": "integer"}} + elif item_type is float: + return {"type": "array", "items": {"type": "number"}} + elif item_type is str: + return {"type": "array", "items": {"type": "string"}} + elif item_type is bool: + return {"type": "array", "items": {"type": "boolean"}} + else: + raise TypeError( + f"Parameter '{param_name}' has unsupported list item type '{item_type.__name__}'. " + f"Supported list types: list[int], list[float], list[str], list[bool]" + ) + + # Unsupported type + type_name = getattr(param_type, "__name__", str(param_type)) + raise TypeError( + f"Parameter '{param_name}' has unsupported type '{type_name}'. " + f"Supported types: int, float, str, bool, list[int], list[float], list[str], list[bool]" + ) + + +def function_to_json_schema(func: Callable[..., Any], description: str | None = None) -> tuple[str, dict[str, Any]]: + """Generate JSON Schema from a Python function's signature and docstring. + + This function introspects a Python function to extract its signature, type hints, + and docstring, then generates a JSON Schema that describes the function's parameters. + + Args: + func (Callable[..., Any]): The Python function to introspect. + description (str | None): Optional description override. If not provided, extracted from docstring. + + Returns: + tuple[str, dict[str, Any]]: Tuple of (function_description, parameters_schema). + + Raises: + TypeError: If a parameter is missing a type annotation or has an unsupported type. + + Examples: + >>> def add(a: float, b: float) -> float: + ... '''Add two numbers. + ... + ... Args: + ... a: First number. + ... b: Second number. + ... ''' + ... return a + b + >>> desc, schema = function_to_json_schema(add) + >>> desc + 'Add two numbers.' + >>> schema['properties']['a'] + {'type': 'number', 'description': 'First number.'} + >>> schema['required'] + ['a', 'b'] + """ + # Get function signature + sig = inspect.signature(func) + + # Get type hints + try: + type_hints = get_type_hints(func) + except Exception as e: + raise TypeError(f"Failed to get type hints for function '{func.__name__}': {e}") from e + + # Parse docstring if description not provided + if description is None: + description, param_descriptions = parse_google_docstring(func.__doc__) + # Use function name as description if no docstring + if not description: + description = func.__name__ + else: + _, param_descriptions = parse_google_docstring(func.__doc__) + + # Build JSON Schema + properties = {} + required = [] + + for param_name, param in sig.parameters.items(): + # Skip *args, **kwargs, self, cls + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + + # Check if parameter has type hint + if param_name not in type_hints: + raise TypeError(f"Parameter '{param_name}' is missing type annotation") + + param_type = type_hints[param_name] + + # Convert type to JSON Schema + schema_type = type_to_json_schema(param_type, param_name) + + # Add description if available + if param_name in param_descriptions: + schema_type["description"] = param_descriptions[param_name] + + properties[param_name] = schema_type + + # Check if required (no default or default is not None for Optional types) + if param.default is inspect.Parameter.empty: + required.append(param_name) + elif param.default is not None: + # Has a non-None default, so it's required but with a default + required.append(param_name) + # If default is None, it's optional (not in required list) + + # Build final parameters schema + parameters_schema = { + "type": "object", + "properties": properties, + } + if required: + parameters_schema["required"] = required + + return description, parameters_schema diff --git a/tests/tests_unit/test_api/test_agents_actions.py b/tests/tests_unit/test_api/test_agents_actions.py index 57d6c21b88..10e34b8110 100644 --- a/tests/tests_unit/test_api/test_agents_actions.py +++ b/tests/tests_unit/test_api/test_agents_actions.py @@ -195,3 +195,300 @@ def test_chat_with_action_result_message(self, cognite_client: CogniteClient, fi # Verify response assert isinstance(response, AgentChatResponse) assert response.text == "The result is 100." + + +class TestFromFunction: + """Test ClientToolAction.from_function() method.""" + + def test_basic_function_with_primitives(self) -> None: + """Test basic function with primitive types.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: The first number. + b: The second number. + + Returns: + The sum. + """ + return a + b + + action = ClientToolAction.from_function(add) + + assert action.name == "add" + assert action.description == "Add two numbers." + assert action.parameters == { + "type": "object", + "properties": { + "a": {"type": "number", "description": "The first number."}, + "b": {"type": "number", "description": "The second number."}, + }, + "required": ["a", "b"], + } + + def test_all_primitive_types(self) -> None: + """Test all supported primitive types.""" + + def example( + name: str, + count: int, + value: float, + enabled: bool, + ) -> None: + """Example function. + + Args: + name: A string parameter. + count: An integer parameter. + value: A float parameter. + enabled: A boolean parameter. + """ + pass + + action = ClientToolAction.from_function(example) + + assert action.parameters["properties"]["name"] == { + "type": "string", + "description": "A string parameter.", + } + assert action.parameters["properties"]["count"] == { + "type": "integer", + "description": "An integer parameter.", + } + assert action.parameters["properties"]["value"] == { + "type": "number", + "description": "A float parameter.", + } + assert action.parameters["properties"]["enabled"] == { + "type": "boolean", + "description": "A boolean parameter.", + } + assert action.parameters["required"] == ["name", "count", "value", "enabled"] + + def test_list_types(self) -> None: + """Test list parameter types.""" + + def process_lists( + names: list[str], + scores: list[float], + counts: list[int], + flags: list[bool], + ) -> None: + """Process various lists. + + Args: + names: List of names. + scores: List of scores. + counts: List of counts. + flags: List of flags. + """ + pass + + action = ClientToolAction.from_function(process_lists) + + assert action.parameters["properties"]["names"] == { + "type": "array", + "items": {"type": "string"}, + "description": "List of names.", + } + assert action.parameters["properties"]["scores"] == { + "type": "array", + "items": {"type": "number"}, + "description": "List of scores.", + } + assert action.parameters["properties"]["counts"] == { + "type": "array", + "items": {"type": "integer"}, + "description": "List of counts.", + } + assert action.parameters["properties"]["flags"] == { + "type": "array", + "items": {"type": "boolean"}, + "description": "List of flags.", + } + + def test_optional_parameters(self) -> None: + """Test optional parameters (default=None).""" + + def greet(name: str, title: str | None = None) -> str: + """Greet a person. + + Args: + name: The person's name. + title: Optional title. + """ + return f"Hello, {title} {name}!" if title else f"Hello, {name}!" + + action = ClientToolAction.from_function(greet) + + assert "name" in action.parameters["required"] + assert "title" not in action.parameters["required"] + assert "title" in action.parameters["properties"] + + def test_custom_name(self) -> None: + """Test custom action name override.""" + + def add(a: float, b: float) -> float: + """Add two numbers.""" + return a + b + + action = ClientToolAction.from_function(add, name="custom_add") + + assert action.name == "custom_add" + + def test_no_docstring_uses_function_name(self) -> None: + """Test that missing docstring uses function name with warning.""" + + def add(a: float, b: float) -> float: + return a + b + + with pytest.warns(UserWarning, match="has no docstring"): + action = ClientToolAction.from_function(add) + + assert action.name == "add" + assert action.description == "add" + + def test_multiline_param_description(self) -> None: + """Test parameter descriptions that span multiple lines.""" + + def example(param: str) -> None: + """Example function. + + Args: + param: This is a long description + that spans multiple lines + in the docstring. + """ + pass + + action = ClientToolAction.from_function(example) + + assert ( + action.parameters["properties"]["param"]["description"] + == "This is a long description that spans multiple lines in the docstring." + ) + + def test_multiline_function_description(self) -> None: + """Test function description that spans multiple lines.""" + + def example(param: str) -> None: + """This is a function description + that spans multiple lines + before the Args section. + + Args: + param: A parameter. + """ + pass + + action = ClientToolAction.from_function(example) + + assert action.description == "This is a function description that spans multiple lines before the Args section." + + def test_missing_type_annotation_raises_error(self) -> None: + """Test that missing type annotation raises TypeError.""" + + def bad(param) -> None: + """Bad function.""" + pass + + with pytest.raises(TypeError, match="missing type annotation"): + ClientToolAction.from_function(bad) + + def test_unsupported_type_raises_error(self) -> None: + """Test that unsupported types raise TypeError.""" + + def bad(data: dict) -> None: + """Bad function. + + Args: + data: A dict parameter. + """ + pass + + with pytest.raises(TypeError, match="unsupported type"): + ClientToolAction.from_function(bad) + + def test_unsupported_list_item_type_raises_error(self) -> None: + """Test that unsupported list item types raise TypeError.""" + + def bad(items: list[dict]) -> None: + """Bad function. + + Args: + items: A list of dicts. + """ + pass + + with pytest.raises(TypeError, match="unsupported list item type"): + ClientToolAction.from_function(bad) + + def test_list_without_item_type_raises_error(self) -> None: + """Test that bare list type raises TypeError.""" + + def bad(items: list) -> None: + """Bad function. + + Args: + items: A list. + """ + pass + + with pytest.raises(TypeError, match="without item type"): + ClientToolAction.from_function(bad) + + def test_function_with_no_parameters(self) -> None: + """Test function with no parameters.""" + + def get_status() -> str: + """Get current status. + + Returns: + The status string. + """ + return "OK" + + action = ClientToolAction.from_function(get_status) + + assert action.name == "get_status" + assert action.description == "Get current status." + assert action.parameters == { + "type": "object", + "properties": {}, + } + assert "required" not in action.parameters + + def test_param_without_description_in_docstring(self) -> None: + """Test parameter without description in docstring.""" + + def example(param: str) -> None: + """Example function.""" + pass + + action = ClientToolAction.from_function(example) + + # Parameter should exist but without description + assert "param" in action.parameters["properties"] + assert "description" not in action.parameters["properties"]["param"] + + def test_dump_and_use_in_chat(self) -> None: + """Test that generated action can be dumped and used in chat.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: First number. + b: Second number. + """ + return a + b + + action = ClientToolAction.from_function(add) + dumped = action.dump() + + assert dumped["type"] == "clientTool" + assert dumped["clientTool"]["name"] == "add" + assert dumped["clientTool"]["description"] == "Add two numbers." + assert dumped["clientTool"]["parameters"]["type"] == "object" diff --git a/tests/tests_unit/test_utils/test_function_introspection.py b/tests/tests_unit/test_utils/test_function_introspection.py new file mode 100644 index 0000000000..0cad30bbd5 --- /dev/null +++ b/tests/tests_unit/test_utils/test_function_introspection.py @@ -0,0 +1,288 @@ +"""Unit tests for function introspection utilities.""" + +from __future__ import annotations + +import pytest + +from cognite.client.utils._function_introspection import ( + function_to_json_schema, + parse_google_docstring, + type_to_json_schema, +) + + +class TestParseGoogleDocstring: + """Tests for parse_google_docstring function.""" + + def test_basic_docstring(self) -> None: + """Test parsing a basic docstring.""" + docstring = """Short description. + + Args: + param1: First parameter. + param2: Second parameter. + + Returns: + Some value. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Short description." + assert params["param1"] == "First parameter." + assert params["param2"] == "Second parameter." + + def test_multiline_description(self) -> None: + """Test multiline function description.""" + docstring = """This is a long description + that spans multiple lines + before the Args section. + + Args: + param: A parameter. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "This is a long description that spans multiple lines before the Args section." + assert params["param"] == "A parameter." + + def test_multiline_param_description(self) -> None: + """Test multiline parameter descriptions.""" + docstring = """Function description. + + Args: + param: This is a long description + that spans multiple lines + in the docstring. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Function description." + assert params["param"] == "This is a long description that spans multiple lines in the docstring." + + def test_param_with_type_annotation(self) -> None: + """Test parameter with type annotation in docstring.""" + docstring = """Function description. + + Args: + param (str): String parameter. + count (int): Integer parameter. + """ + desc, params = parse_google_docstring(docstring) + + assert params["param"] == "String parameter." + assert params["count"] == "Integer parameter." + + def test_no_docstring(self) -> None: + """Test with None docstring.""" + desc, params = parse_google_docstring(None) + + assert desc == "" + assert params == {} + + def test_no_args_section(self) -> None: + """Test docstring without Args section.""" + docstring = """Just a description. + + Returns: + Some value. + """ + desc, params = parse_google_docstring(docstring) + + assert desc == "Just a description." + assert params == {} + + +class TestTypeToJsonSchema: + """Tests for type_to_json_schema function.""" + + def test_primitive_int(self) -> None: + """Test int type.""" + schema = type_to_json_schema(int, "param") + assert schema == {"type": "integer"} + + def test_primitive_float(self) -> None: + """Test float type.""" + schema = type_to_json_schema(float, "param") + assert schema == {"type": "number"} + + def test_primitive_str(self) -> None: + """Test str type.""" + schema = type_to_json_schema(str, "param") + assert schema == {"type": "string"} + + def test_primitive_bool(self) -> None: + """Test bool type.""" + schema = type_to_json_schema(bool, "param") + assert schema == {"type": "boolean"} + + def test_list_int(self) -> None: + """Test list[int] type.""" + schema = type_to_json_schema(list[int], "param") + assert schema == {"type": "array", "items": {"type": "integer"}} + + def test_list_float(self) -> None: + """Test list[float] type.""" + schema = type_to_json_schema(list[float], "param") + assert schema == {"type": "array", "items": {"type": "number"}} + + def test_list_str(self) -> None: + """Test list[str] type.""" + schema = type_to_json_schema(list[str], "param") + assert schema == {"type": "array", "items": {"type": "string"}} + + def test_list_bool(self) -> None: + """Test list[bool] type.""" + schema = type_to_json_schema(list[bool], "param") + assert schema == {"type": "array", "items": {"type": "boolean"}} + + def test_bare_list_raises_error(self) -> None: + """Test that bare list type raises TypeError.""" + with pytest.raises(TypeError, match="without item type"): + type_to_json_schema(list, "param") + + def test_unsupported_type_raises_error(self) -> None: + """Test that unsupported types raise TypeError.""" + with pytest.raises(TypeError, match="unsupported type"): + type_to_json_schema(dict, "param") + + def test_unsupported_list_item_type_raises_error(self) -> None: + """Test that unsupported list item types raise TypeError.""" + with pytest.raises(TypeError, match="unsupported list item type"): + type_to_json_schema(list[dict], "param") + + +class TestFunctionToJsonSchema: + """Tests for function_to_json_schema function.""" + + def test_basic_function(self) -> None: + """Test basic function with primitives.""" + + def add(a: float, b: float) -> float: + """Add two numbers. + + Args: + a: First number. + b: Second number. + """ + return a + b + + desc, schema = function_to_json_schema(add) + + assert desc == "Add two numbers." + assert schema == { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number."}, + "b": {"type": "number", "description": "Second number."}, + }, + "required": ["a", "b"], + } + + def test_all_primitive_types(self) -> None: + """Test all supported primitive types.""" + + def example(name: str, count: int, value: float, enabled: bool) -> None: + """Example function. + + Args: + name: A string. + count: An integer. + value: A float. + enabled: A boolean. + """ + pass + + desc, schema = function_to_json_schema(example) + + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["count"]["type"] == "integer" + assert schema["properties"]["value"]["type"] == "number" + assert schema["properties"]["enabled"]["type"] == "boolean" + assert schema["required"] == ["name", "count", "value", "enabled"] + + def test_optional_parameters(self) -> None: + """Test optional parameters (default=None).""" + + def greet(name: str, title: str | None = None) -> str: + """Greet a person. + + Args: + name: The person's name. + title: Optional title. + """ + return f"Hello, {title} {name}!" if title else f"Hello, {name}!" + + desc, schema = function_to_json_schema(greet) + + assert "name" in schema["required"] + assert "title" not in schema["required"] + assert "title" in schema["properties"] + + def test_function_with_no_parameters(self) -> None: + """Test function with no parameters.""" + + def get_status() -> str: + """Get current status. + + Returns: + The status. + """ + return "OK" + + desc, schema = function_to_json_schema(get_status) + + assert desc == "Get current status." + assert schema == {"type": "object", "properties": {}} + assert "required" not in schema + + def test_function_with_no_docstring(self) -> None: + """Test function with no docstring.""" + + def add(a: float, b: float) -> float: + return a + b + + desc, schema = function_to_json_schema(add) + + assert desc == "add" + assert "a" in schema["properties"] + assert "b" in schema["properties"] + + def test_custom_description_override(self) -> None: + """Test custom description override.""" + + def add(a: float, b: float) -> float: + """Add two numbers.""" + return a + b + + desc, schema = function_to_json_schema(add, description="Custom description") + + assert desc == "Custom description" + + def test_missing_type_annotation_raises_error(self) -> None: + """Test that missing type annotation raises TypeError.""" + + def bad(param) -> None: + """Bad function.""" + pass + + with pytest.raises(TypeError, match="missing type annotation"): + function_to_json_schema(bad) + + def test_list_types(self) -> None: + """Test list parameter types.""" + + def process_lists(names: list[str], scores: list[float]) -> None: + """Process lists. + + Args: + names: List of names. + scores: List of scores. + """ + pass + + desc, schema = function_to_json_schema(process_lists) + + assert schema["properties"]["names"]["type"] == "array" + assert schema["properties"]["names"]["items"]["type"] == "string" + assert schema["properties"]["scores"]["type"] == "array" + assert schema["properties"]["scores"]["items"]["type"] == "number"