-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtools.py
More file actions
72 lines (58 loc) · 2.21 KB
/
tools.py
File metadata and controls
72 lines (58 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from typing import Callable
from typing import Any, Dict, get_origin, get_args
from pydantic import BaseModel
from typing import Type
# Module-level registry mapping a tool's name to its Tool instance.
# register_tool() writes entries; get_tool() and get_tool_names() read them.
ToolRegistry = {}
def get_input_schema_dict(input_schema:Type[BaseModel]) -> dict:
    """Return the JSON schema of a model class as a dict.

    Args:
        input_schema: Model class exposing ``model_json_schema()``.

    Returns:
        Whatever ``input_schema.model_json_schema()`` produces.
    """
    return input_schema.model_json_schema()
class ToolResult:
    """Value object returned by a tool: a ``result`` string and a ``content`` string."""

    result: str
    content: str

    def __init__(self, result:str, content:str):
        """Store both values exactly as given."""
        self.result, self.content = result, content

    def get_result(self):
        """Return the stored ``result`` value."""
        return self.result

    def get_content(self):
        """Return the stored ``content`` value."""
        return self.content
class Tool:
    """Registry entry holding a tool's metadata and its bound implementation.

    The constructor stores only the metadata. The implementation is attached
    afterwards through the ``func`` and ``tool_class`` attributes (this is
    what ``register_tool`` does); until both are set the tool is not callable.

    Attributes:
        name: Name of the tool.
        description: Description of the tool.
        input_schema: Input schema as a plain dict.
        func: Callable run by ``__call__``; it receives a new ``tool_class``
            instance as its first positional argument.
        tool_class: Class instantiated with no arguments on every call.
    """

    name: str
    description: str
    input_schema: dict
    func: Callable
    tool_class: type

    def __init__(self, name:str, description:str, input_schema:dict):
        """Store the tool metadata; ``func`` and ``tool_class`` are set later."""
        self.name = name
        self.description = description
        self.input_schema = input_schema

    # The return annotation is quoted so it is not evaluated when the class is defined.
    def __call__(self, *args, **kwargs) -> "ToolResult":
        """Invoke the bound implementation on a fresh ``tool_class`` instance.

        Args:
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            AttributeError: If ``func`` or ``tool_class`` has not been attached.
        """
        func = getattr(self, "func", None)
        tool_class = getattr(self, "tool_class", None)
        if func is None or tool_class is None:
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that says which tool is unbound and why.
            raise AttributeError(
                f"Tool {self.name} has no implementation bound; "
                "set func and tool_class (e.g. via register_tool) before calling it"
            )
        return func(tool_class(), *args, **kwargs)

    def to_dict(self):
        """Return the tool metadata as a dict with name, description and input_schema."""
        return {
            "name": self.name,
            "description": self.description,
            "input_schema": self.input_schema,
        }

    def __str__(self):
        """Return a one-line summary of the tool metadata."""
        return f"Tool(name={self.name}, description={self.description}, input_schema={self.input_schema})"
def register_tool(cls):
    """Validate a tool class and add it to ``ToolRegistry`` under ``cls.name``.

    The class must expose ``name``, ``description``, ``input_schema`` and
    ``_run``. ``cls`` is returned unchanged, so the function can be applied
    as a class decorator.

    Args:
        cls: Tool class to register.

    Returns:
        The same ``cls`` object.

    Raises:
        Exception: If any of the required attributes is missing.
    """
    # (attribute, wording used in the error message), checked in this order
    requirements = (
        ("name", "a name"),
        ("description", "a description"),
        ("input_schema", "an input schema"),
        ("_run", "a _run method"),
    )
    for attribute, wording in requirements:
        if not hasattr(cls, attribute):
            raise Exception(f"Tool {cls.__name__} does not have {wording}")

    schema_dict = get_input_schema_dict(cls.input_schema)
    entry = Tool(name=cls.name, description=cls.description, input_schema=schema_dict)
    # The implementation is attached after construction: the function to run
    # and the class that is instantiated on each call.
    entry.func = cls._run
    entry.tool_class = cls
    ToolRegistry[cls.name] = entry
    return cls
def get_tool(name:str) -> Tool:
    """Look up a registered tool by name.

    Args:
        name: Key under which the tool was stored in ``ToolRegistry``.

    Returns:
        The ``Tool`` stored under ``name``.

    Raises:
        Exception: If no tool is registered under ``name``.
    """
    if name in ToolRegistry:
        return ToolRegistry[name]
    raise Exception(f"Tool {name} not found, please register the tool first")
def get_tool_names() -> list[str]:
    """Return the names of all registered tools.

    Returns:
        A new list of the keys of ``ToolRegistry`` in insertion order. It is
        a snapshot: later registrations do not change a list already returned.
    """
    # dict.keys() returns a view object, not a list; materialise it so the
    # result matches the annotation (indexable, JSON-serialisable).
    return list(ToolRegistry)