-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
123 lines (102 loc) · 4.02 KB
/
main.py
File metadata and controls
123 lines (102 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import argparse
from dotenv import load_dotenv
from google import genai
from google.genai import types
from configs import MODEL_NAME, SYSTEM_PROMPT
from call_function import available_functions
from functions.get_files_info import get_files_info
from functions.get_file_content import get_file_content
from functions.run_python_file import run_python_file
from functions.write_file import write_file
def main():
    """Run the tool-calling agent loop for a single command-line prompt.

    Parses a positional ``prompt`` and an optional ``--verbose`` flag, calls
    ``load_dotenv()`` and reads ``GEMINI_API_KEY`` from the environment, then
    repeatedly sends the running conversation to ``MODEL_NAME``.

    Each round does the following:
    - The model's candidate contents are appended to the conversation history.
    - If the model requested function calls, each one is executed through
      ``call_function`` and its result is appended.
    - If the model requested no function call, its text is printed as the
      final response and the loop stops.

    At most ``max_iterations`` rounds are performed. If that limit is reached
    without a final response, a notice is printed.

    Raises:
        RuntimeError: if ``GEMINI_API_KEY`` is not set, or a response carries
            no usage metadata.
        Exception: if a called function returned no response payload.
    """
    max_iterations = 20

    parser = argparse.ArgumentParser(description="Chatbot")
    parser.add_argument("prompt", type=str, help="User prompt")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
    args = parser.parse_args()

    load_dotenv()
    api_key = os.environ.get("GEMINI_API_KEY")
    if api_key is None:
        raise RuntimeError("API Key is None!")

    client = genai.Client(api_key=api_key)
    config = types.GenerateContentConfig(
        tools=[available_functions], system_instruction=SYSTEM_PROMPT
    )

    prompt = args.prompt
    messages = [types.Content(role="user", parts=[types.Part(text=prompt)])]

    for _ in range(max_iterations):
        response = client.models.generate_content(
            model=MODEL_NAME,
            contents=messages,
            config=config,
        )
        usage = response.usage_metadata
        if usage is None:
            raise RuntimeError("API request failed.")

        # Keep the model's turn(s) in the history so the next request sees
        # which tools it asked for. Skip candidates that carry no content
        # rather than appending None to the conversation.
        for candidate in response.candidates or []:
            if candidate.content is not None:
                messages.append(candidate.content)

        if args.verbose:
            print(f"User prompt: {prompt}")
            print(f"Prompt tokens: {usage.prompt_token_count}")
            print(f"Response tokens: {usage.candidates_token_count}")

        function_calls = response.function_calls
        if not function_calls:
            # No tool was requested: the model's text is the final answer.
            print("Final Response:")
            print(response.text)
            break

        for function_call in function_calls:
            call_result = call_function(function_call, verbose=args.verbose)
            result = call_result.parts[0].function_response.response
            if result is None:
                raise Exception(f"Error getting function '{function_call.name}' result")
            if args.verbose:
                print(f"-> {result}")
            # Feed the tool output back so the model can use it next round.
            messages.append(call_result)
    else:
        # The loop ran out of rounds without ever hitting the `break` above.
        print(f"Maximum iterations ({max_iterations}) reached without a final response.")
def call_function(function_call, verbose=False, *, working_directory="./calculator"):
    """Execute a model-requested function call and wrap the result for the model.

    Args:
        function_call: The function call taken from the model response. Its
            ``name`` selects one of the known tool functions, and its ``args``
            (possibly ``None``) are passed as keyword arguments.
        verbose: When true, print the call together with its arguments;
            otherwise print only the function name.
        working_directory: Passed as the first positional argument to every
            tool function. It defaults to ``"./calculator"``, the value that
            was previously hard-coded.

    Returns:
        A ``types.Content`` with role ``"tool"`` containing one function-response
        part. The response is ``{"result": <return value>}`` on success, or
        ``{"error": "Unknown function: <name>"}`` when the name is not one of
        the known tool functions.
    """
    if verbose:
        print(f"Calling function: {function_call.name}({function_call.args})")
    else:
        print(f" - Calling function: {function_call.name}")

    functions = {
        'get_files_info': get_files_info,
        'get_file_content': get_file_content,
        'run_python_file': run_python_file,
        'write_file': write_file,
    }

    # Look up with .get() BEFORE calling. Indexing the dict directly would
    # raise KeyError for an unknown name and never reach the error response.
    called_function = functions.get(function_call.name)
    if called_function is None:
        return types.Content(
            role="tool",
            parts=[
                types.Part.from_function_response(
                    name=function_call.name,
                    response={"error": f"Unknown function: {function_call.name}"},
                )
            ],
        )

    # The model may send a call with no arguments (args is None); `**None`
    # would raise TypeError, so fall back to an empty mapping.
    function_result = called_function(working_directory, **(function_call.args or {}))
    return types.Content(
        role="tool",
        parts=[
            types.Part.from_function_response(
                name=function_call.name,
                response={"result": function_result},
            )
        ],
    )
if __name__ == "__main__":
    # Run the agent only when this file is executed as a script, not on import.
    main()